{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import Parameter\n",
    "import torch.optim as optim\n",
    "\n",
    "from typing import *\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_ROOT = Path(\"../data/brown\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "N_EPOCHS = 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building an LSTM from scratch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we'll be building our own LSTM and delving into why it performs so well across a wide range of tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# The Basics of the LSTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we actually build the LSTM, we'll need to understand its basic mechanism."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The diagram below shows the flow of information in an LSTM cell (image from Wikipedia)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://upload.wikimedia.org/wikipedia/commons/thumb/3/3b/The_LSTM_cell.png/1920px-The_LSTM_cell.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The equations for the LSTM look like this:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\\begin{array}{ll} \\\\\n",
    "            i_t = \\sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{(t-1)} + b_{hi}) \\\\\n",
    "            f_t = \\sigma(W_{if} x_t + b_{if} + W_{hf} h_{(t-1)} + b_{hf}) \\\\\n",
    "            g_t = \\tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{(t-1)} + b_{hg}) \\\\\n",
    "            o_t = \\sigma(W_{io} x_t + b_{io} + W_{ho} h_{(t-1)} + b_{ho}) \\\\\n",
    "            c_t = f_t * c_{(t-1)} + i_t * g_t \\\\\n",
    "            h_t = o_t * \\tanh(c_t) \\\\\n",
    "        \\end{array}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It seems complex, but when you pick it apart, the LSTM is actually very simple. The core of the LSTM is the following equation:\n",
    "\n",
    "\\begin{array}{ll} \\\\\n",
    "            c_t = f_t * c_{(t-1)} + i_t * g_t \\\\\n",
    "\\end{array}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's pick this equation apart: $ c_t $ is the new cell state, which is basically the memory of the LSTM. \n",
    "\n",
    "$ f_t $ is called the \"forget gate\": it dictates how much of the previous cell state to **retain** (but is slightly confusingly named the forget gate). \n",
    "\n",
    "$ i_t $ is the \"input gate\" and dictates how much to update the cell state with new information.\n",
    "\n",
    "Finally, $ g_t $ is the information we use to update the cell state.\n",
    "\n",
    "Basically, an LSTM chooses to keep a certain portion of its previous cell state and add a certain amount of new information. These proportions are controlled using gates."
   ]
  },
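  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration with made-up scalar values (these numbers are assumptions chosen for the example, not taken from a trained model), suppose $ f_t = 0.9 $, $ i_t = 0.1 $, $ c_{(t-1)} = 2.0 $ and $ g_t = -1.0 $. The update then works out to:\n",
    "\n",
    "$$ c_t = f_t * c_{(t-1)} + i_t * g_t = 0.9 \\times 2.0 + 0.1 \\times (-1.0) = 1.7 $$\n",
    "\n",
    "Most of the old memory survives (the $ 1.8 $ term), and the new information only nudges it slightly."
   ]
  },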
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's contrast this update rule with the update rule of a simpler RNN:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "$$ c_t = \\tanh(W_h c_{t-1} + W_i x_t) $$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "(To make the contrast clearer, I'm representing the hidden state of the RNN as $ c_t $.)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, there is a huge difference between the simple RNN's update rule and the LSTM's update rule. Whereas the RNN computes the new hidden state from scratch based on the previous hidden state and the input, the LSTM computes the new hidden state by choosing what to **add** to the current state. This is similar to how ResNets learn: they learn what to add to the current state/block instead of directly learning the new state. In other words, LSTMs are great primarily because they are **additive**. We'll formalize this intuition later when we examine the gradient flow, but this is the basic idea behind the LSTM.\n",
    "\n",
    "Now that we have a basic understanding, let's start coding."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Side Note: One thing that is slightly confusing about the LSTM is that it has two \"hidden states\": $ c_t $ and $ h_t $. Intuitively, $ c_t $ is the \"internal\" hidden state that retains important information for longer timesteps, whereas $ h_t $ is the \"external\" hidden state that exposes that information to the outside world.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Side Note: If you're looking carefully, you'll notice that the bias terms are redundant. The reason they are there is for compatibility with the CuDNN backend. Until we touch on CuDNN, we'll use a single bias term."
   ]
  },
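  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the redundancy concrete, here is a small sketch with random tensors (the sizes and names such as `b_x` and `b_h` are arbitrary choices for illustration, not part of the model we build below). It checks that two separate bias vectors give the same gate activation as a single bias equal to their sum, up to floating-point rounding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Illustrative shapes only: batch of 4, 8 input features, 16 hidden units\n",
    "x_demo, h_demo = torch.randn(4, 8), torch.randn(4, 16)\n",
    "W_x, W_h = torch.randn(8, 16), torch.randn(16, 16)\n",
    "b_x, b_h = torch.randn(16), torch.randn(16)\n",
    "\n",
    "# Gate computed with two biases, as in the equations above\n",
    "two_biases = torch.sigmoid(x_demo @ W_x + b_x + h_demo @ W_h + b_h)\n",
    "# Same gate computed with one merged bias\n",
    "one_bias = torch.sigmoid(x_demo @ W_x + h_demo @ W_h + (b_x + b_h))\n",
    "\n",
    "# Compare with a small tolerance, since the additions happen in a different order\n",
    "print(torch.allclose(two_biases, one_bias, atol=1e-6))"
   ]
  },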
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Implementing the LSTM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll be using PyTorch to write our own LSTM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from enum import IntEnum\n",
    "class Dim(IntEnum):\n",
    "    batch = 0\n",
    "    seq = 1\n",
    "    feature = 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NaiveLSTM(nn.Module):\n",
    "    def __init__(self, input_sz: int, hidden_sz: int):\n",
    "        super().__init__()\n",
    "        self.input_size = input_sz\n",
    "        self.hidden_size = hidden_sz\n",
    "        # input gate\n",
    "        self.W_ii = Parameter(torch.Tensor(input_sz, hidden_sz))\n",
    "        self.W_hi = Parameter(torch.Tensor(hidden_sz, hidden_sz))\n",
    "        self.b_i = Parameter(torch.Tensor(hidden_sz))\n",
    "        # forget gate\n",
    "        self.W_if = Parameter(torch.Tensor(input_sz, hidden_sz))\n",
    "        self.W_hf = Parameter(torch.Tensor(hidden_sz, hidden_sz))\n",
    "        self.b_f = Parameter(torch.Tensor(hidden_sz))\n",
    "        # ???\n",
    "        self.W_ig = Parameter(torch.Tensor(input_sz, hidden_sz))\n",
    "        self.W_hg = Parameter(torch.Tensor(hidden_sz, hidden_sz))\n",
    "        self.b_g = Parameter(torch.Tensor(hidden_sz))\n",
    "        # output gate\n",
    "        self.W_io = Parameter(torch.Tensor(input_sz, hidden_sz))\n",
    "        self.W_ho = Parameter(torch.Tensor(hidden_sz, hidden_sz))\n",
    "        self.b_o = Parameter(torch.Tensor(hidden_sz))\n",
    "        \n",
    "        self.init_weights()\n",
    "    \n",
    "    def init_weights(self):\n",
    "        for p in self.parameters():\n",
    "            if p.data.ndimension() >= 2:\n",
    "                nn.init.xavier_uniform_(p.data)\n",
    "            else:\n",
    "                nn.init.zeros_(p.data)\n",
    "        \n",
    "    def forward(self, x: torch.Tensor, \n",
    "                init_states: Optional[Tuple[torch.Tensor, torch.Tensor]]=None\n",
    "               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n",
    "        \"\"\"Assumes x is of shape (batch, sequence, feature)\"\"\"\n",
    "        bs, seq_sz, _ = x.size()\n",
    "        hidden_seq = []\n",
    "        if init_states is None:\n",
    "            h_t, c_t = torch.zeros(self.hidden_size).to(x.device), torch.zeros(self.hidden_size).to(x.device)\n",
    "        else:\n",
    "            h_t, c_t = init_states\n",
    "        for t in range(seq_sz): # iterate over the time steps\n",
    "            x_t = x[:, t, :]\n",
    "            i_t = torch.sigmoid(x_t @ self.W_ii + h_t @ self.W_hi + self.b_i)\n",
    "            f_t = torch.sigmoid(x_t @ self.W_if + h_t @ self.W_hf + self.b_f)\n",
    "            g_t = torch.tanh(x_t @ self.W_ig + h_t @ self.W_hg + self.b_g)\n",
    "            o_t = torch.sigmoid(x_t @ self.W_io + h_t @ self.W_ho + self.b_o)\n",
    "            c_t = f_t * c_t + i_t * g_t\n",
    "            h_t = o_t * torch.tanh(c_t)\n",
    "            hidden_seq.append(h_t.unsqueeze(Dim.batch))\n",
    "        hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)\n",
    "        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)\n",
    "        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()\n",
    "        return hidden_seq, (h_t, c_t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's test it on some synthetic data:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "bs, seq_len, feat_sz, hidden_sz = 5, 10, 32, 16\n",
    "arr = torch.randn(bs, seq_len, feat_sz)\n",
    "lstm = NaiveLSTM(feat_sz, hidden_sz)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "hs, (hn, cn) = lstm(arr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 10, 16])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It looks like it works!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Testing our implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we've covered the basics and have a minimally working LSTM, we'll put our model into action. Our testbed will be a character-level language modeling task. We'll be using the Brown Corpus, which you can download with the commands below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p {DATA_ROOT}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current\r\n",
      "                                 Dload  Upload   Total   Spent    Left  Speed\r\n",
      "\r",
      "  0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0curl: (6) Could not resolve host: www.sls.hawaii.edu\r\n"
     ]
    }
   ],
   "source": [
    "!curl http://www.sls.hawaii.edu/bley-vroman/brown.txt -o {DATA_ROOT / \"brown.txt\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll let AllenNLP handle the complexity of training the language model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]02/15/2019 16:16:55 - INFO - allennlp.data.dataset_readers.language_modeling -   Creating dataset from all text in file: ../data/brown/brown.txt\n",
      "\n",
      "  0%|          | 0/11994 [00:00<?, ?it/s]\u001b[A\n",
      " 47%|████▋     | 5598/11994 [00:00<00:00, 55979.39it/s]\u001b[A\n",
      " 80%|███████▉  | 9561/11994 [00:00<00:00, 46578.75it/s]\u001b[A\n",
      "11994it [00:12, 987.71it/s]4 [00:00<00:00, 45664.13it/s]\u001b[A\n",
      "02/15/2019 16:16:58 - INFO - allennlp.data.vocabulary -   Fitting token dictionary from dataset.\n",
      "100%|██████████| 10794/10794 [00:05<00:00, 1917.79it/s]\n"
     ]
    }
   ],
   "source": [
    "from allennlp.data.dataset_readers import LanguageModelingReader\n",
    "from allennlp.data.tokenizers import CharacterTokenizer\n",
    "from allennlp.data.token_indexers import SingleIdTokenIndexer\n",
    "from allennlp.data import Vocabulary\n",
    "from allennlp.data.iterators import BasicIterator\n",
    "from allennlp.training import Trainer\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "char_tokenizer = CharacterTokenizer(lowercase_characters=True)\n",
    "\n",
    "reader = LanguageModelingReader(\n",
    "    tokens_per_instance=500,\n",
    "    tokenizer=char_tokenizer,\n",
    "    token_indexers = {\"tokens\": SingleIdTokenIndexer()},\n",
    ")\n",
    "\n",
    "train_ds = reader.read(DATA_ROOT / \"brown.txt\")\n",
    "train_ds, val_ds = train_test_split(train_ds, random_state=0, test_size=0.1)\n",
    "\n",
    "vocab = Vocabulary.from_instances(train_ds)\n",
    "\n",
    "iterator = BasicIterator(batch_size=32)\n",
    "iterator.index_with(vocab)\n",
    "\n",
    "def train(model: nn.Module, epochs: int=10):\n",
    "    trainer = Trainer(\n",
    "        model=model.cuda() if torch.cuda.is_available() else model,\n",
    "        optimizer=optim.Adam(model.parameters()),\n",
    "        iterator=iterator, train_dataset=train_ds, \n",
    "        validation_dataset=val_ds, num_epochs=epochs,\n",
    "        cuda_device=0 if torch.cuda.is_available() else -1\n",
    "    )\n",
    "    return trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from allennlp.modules.seq2seq_encoders import PytorchSeq2SeqWrapper\n",
    "from allennlp.modules.token_embedders import Embedding\n",
    "from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder\n",
    "from allennlp.models import Model\n",
    "from allennlp.nn.util import get_text_field_mask\n",
    "\n",
    "class LanguageModel(Model):\n",
    "    def __init__(self, encoder: nn.Module, vocab: Vocabulary,\n",
    "                 embedding_dim: int=50):\n",
    "        super().__init__(vocab=vocab)\n",
    "        # char embedding\n",
    "        self.vocab_size = vocab.get_vocab_size()\n",
    "        self.padding_idx = vocab.get_token_index(\"@@PADDING@@\")\n",
    "        token_embedding = Embedding(\n",
    "            num_embeddings=vocab.get_vocab_size(),\n",
    "            embedding_dim=embedding_dim,\n",
    "            padding_index=self.padding_idx,\n",
    "        )\n",
    "        self.embedding = BasicTextFieldEmbedder({\"tokens\": token_embedding})\n",
    "        self.encoder = encoder\n",
    "        self.projection = nn.Linear(self.encoder.hidden_size, self.vocab_size)\n",
    "        self.loss = nn.CrossEntropyLoss(ignore_index=self.padding_idx)\n",
    "    \n",
    "    def forward(self, input_tokens: Dict[str, torch.Tensor],\n",
    "                output_tokens: Dict[str, torch.Tensor]):\n",
    "        # embed the characters, encode the sequence, then project to vocabulary logits\n",
    "        embs = self.embedding(input_tokens)\n",
    "        x, _ = self.encoder(embs)\n",
    "        x = self.projection(x)\n",
    "        if output_tokens is not None:\n",
    "            loss = self.loss(x.view((-1, self.vocab_size)), output_tokens[\"tokens\"].flatten())\n",
    "        else:\n",
    "            loss = None\n",
    "        return {\"loss\": loss, \"logits\": x}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's try training:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "02/15/2019 16:17:04 - WARNING - allennlp.training.trainer -   You provided a validation dataset but patience was set to None, meaning that early stopping is disabled\n",
      "02/15/2019 16:17:04 - INFO - allennlp.training.trainer -   Beginning training.\n",
      "02/15/2019 16:17:04 - INFO - allennlp.training.trainer -   Epoch 0/0\n",
      "02/15/2019 16:17:04 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1729.343488\n",
      "02/15/2019 16:17:04 - INFO - allennlp.training.trainer -   Training\n",
      "loss: 2.6929 ||: 100%|██████████| 338/338 [17:04<00:00,  2.50s/it]\n",
      "02/15/2019 16:34:09 - INFO - allennlp.training.trainer -   Validating\n",
      "loss: 2.3200 ||: 100%|██████████| 38/38 [00:10<00:00,  4.15it/s]\n",
      "02/15/2019 16:34:19 - INFO - allennlp.training.trainer -                     Training |  Validation\n",
      "02/15/2019 16:34:19 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1729.343  |       N/A\n",
      "02/15/2019 16:34:19 - INFO - allennlp.training.trainer -   loss          |     2.693  |     2.320\n",
      "02/15/2019 16:34:19 - INFO - allennlp.training.trainer -   Epoch duration: 00:17:14\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'peak_cpu_memory_MB': 1729.343488,\n",
       " 'training_duration': '00:17:14',\n",
       " 'training_start_epoch': 0,\n",
       " 'training_epochs': 0,\n",
       " 'epoch': 0,\n",
       " 'training_loss': 2.692909450926019,\n",
       " 'training_cpu_memory_MB': 1729.343488,\n",
       " 'validation_loss': 2.3199655068548104,\n",
       " 'best_epoch': 0,\n",
       " 'best_validation_loss': 2.3199655068548104}"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm_naive = LanguageModel(NaiveLSTM(50, 125), vocab)\n",
    "train(lm_naive, epochs=N_EPOCHS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's compare with PyTorch's built-in `nn.LSTM`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "02/15/2019 16:34:19 - WARNING - allennlp.training.trainer -   You provided a validation dataset but patience was set to None, meaning that early stopping is disabled\n",
      "02/15/2019 16:34:19 - INFO - allennlp.training.trainer -   Beginning training.\n",
      "02/15/2019 16:34:19 - INFO - allennlp.training.trainer -   Epoch 0/0\n",
      "02/15/2019 16:34:19 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1828.790272\n",
      "02/15/2019 16:34:19 - INFO - allennlp.training.trainer -   Training\n",
      "loss: 2.6951 ||: 100%|██████████| 338/338 [03:34<00:00,  1.71it/s]\n",
      "02/15/2019 16:37:53 - INFO - allennlp.training.trainer -   Validating\n",
      "loss: 2.3065 ||: 100%|██████████| 38/38 [00:07<00:00,  6.44it/s]\n",
      "02/15/2019 16:38:01 - INFO - allennlp.training.trainer -                     Training |  Validation\n",
      "02/15/2019 16:38:01 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1828.790  |       N/A\n",
      "02/15/2019 16:38:01 - INFO - allennlp.training.trainer -   loss          |     2.695  |     2.307\n",
      "02/15/2019 16:38:01 - INFO - allennlp.training.trainer -   Epoch duration: 00:03:41\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'peak_cpu_memory_MB': 1828.790272,\n",
       " 'training_duration': '00:03:41',\n",
       " 'training_start_epoch': 0,\n",
       " 'training_epochs': 0,\n",
       " 'epoch': 0,\n",
       " 'training_loss': 2.695129863609224,\n",
       " 'training_cpu_memory_MB': 1828.790272,\n",
       " 'validation_loss': 2.30654828171981,\n",
       " 'best_epoch': 0,\n",
       " 'best_validation_loss': 2.30654828171981}"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm_comparison = LanguageModel(nn.LSTM(50, 125, batch_first=True), vocab)\n",
    "train(lm_comparison, epochs=N_EPOCHS)"
   ]
  },
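  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our model is a lot slower (about 17 minutes per epoch versus under 4 minutes for `nn.LSTM`), but the validation losses after one epoch are close (2.320 versus 2.307), so the implementation looks good! We'll look at how we can make it faster later."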
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's compare the performance of the LSTM with a much simpler RNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleRNN(nn.Module):\n",
    "    def __init__(self, input_sz: int, hidden_sz: int):\n",
    "        super().__init__()\n",
    "        self.input_sz, self.hidden_size = input_sz, hidden_sz\n",
    "        self.weight_ih = Parameter(torch.Tensor(input_sz, hidden_sz))\n",
    "        self.weight_hh = Parameter(torch.Tensor(hidden_sz, hidden_sz))\n",
    "        self.bias_hh = Parameter(torch.Tensor(hidden_sz))\n",
    "        \n",
    "        self.init_weights()\n",
    "\n",
    "    def init_weights(self):\n",
    "        nn.init.xavier_uniform_(self.weight_ih)\n",
    "        nn.init.xavier_uniform_(self.weight_hh)\n",
    "        nn.init.zeros_(self.bias_hh)\n",
    "    \n",
    "    def forward(self, x: torch.Tensor, init_state=None) -> Tuple[torch.Tensor, torch.Tensor]:\n",
    "        \"\"\"Assumes x is of shape (batch, sequence, feature)\"\"\"\n",
    "        bs, seq_sz, _ = x.size()\n",
    "        hidden_seq = []\n",
    "        if init_state is None:\n",
    "            h_t = torch.zeros(bs, self.hidden_size, device=x.device)\n",
    "        else:\n",
    "            h_t = init_state\n",
    "\n",
    "        for t in range(seq_sz):\n",
    "            x_t = x[:, t, :]\n",
    "            h_t = torch.tanh(x_t @ self.weight_ih + h_t @ self.weight_hh + self.bias_hh)\n",
    "            hidden_seq.append(h_t.unsqueeze(Dim.batch))\n",
    "        hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)\n",
    "        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)\n",
    "        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()\n",
    "        return hidden_seq, h_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "02/15/2019 16:38:01 - WARNING - allennlp.training.trainer -   You provided a validation dataset but patience was set to None, meaning that early stopping is disabled\n",
      "02/15/2019 16:38:01 - INFO - allennlp.training.trainer -   Beginning training.\n",
      "02/15/2019 16:38:01 - INFO - allennlp.training.trainer -   Epoch 0/0\n",
      "02/15/2019 16:38:01 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1905.47968\n",
      "02/15/2019 16:38:01 - INFO - allennlp.training.trainer -   Training\n",
      "loss: 2.5764 ||: 100%|██████████| 338/338 [13:24<00:00,  1.95s/it]\n",
      "02/15/2019 16:51:26 - INFO - allennlp.training.trainer -   Validating\n",
      "loss: 2.2059 ||: 100%|██████████| 38/38 [00:03<00:00, 13.33it/s]\n",
      "02/15/2019 16:51:29 - INFO - allennlp.training.trainer -                     Training |  Validation\n",
      "02/15/2019 16:51:29 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1905.480  |       N/A\n",
      "02/15/2019 16:51:29 - INFO - allennlp.training.trainer -   loss          |     2.576  |     2.206\n",
      "02/15/2019 16:51:29 - INFO - allennlp.training.trainer -   Epoch duration: 00:13:27\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'peak_cpu_memory_MB': 1905.47968,\n",
       " 'training_duration': '00:13:27',\n",
       " 'training_start_epoch': 0,\n",
       " 'training_epochs': 0,\n",
       " 'epoch': 0,\n",
       " 'training_loss': 2.576429249266901,\n",
       " 'training_cpu_memory_MB': 1905.47968,\n",
       " 'validation_loss': 2.205946420368395,\n",
       " 'best_epoch': 0,\n",
       " 'best_validation_loss': 2.205946420368395}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm_simplernn = LanguageModel(SimpleRNN(50, 125), vocab)\n",
    "train(lm_simplernn, epochs=N_EPOCHS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Understanding the dynamics of LSTM learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Why exactly do LSTMs learn so well? Let's analyze the dynamics of LSTM learning by checking how the gradients change and comparing them to the gradients of a simple RNN."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_batch = next(iterator(train_ds))\n",
    "test_embeddings = lm_naive.embedding(test_batch[\"input_tokens\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The gradient dynamics of simple RNNs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, let's check how the gradient of a simple RNN's hidden state with respect to the initial hidden state changes as the number of time steps grows."
   ]
  },
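  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a rough sketch of what we might expect, using the row-vector convention of the `SimpleRNN` code, where $ h_t = \\tanh(x_t W_{ih} + h_{t-1} W_{hh} + b_{hh}) $. The shorthand $ D_t = \\mathrm{diag}(1 - h_t \\odot h_t) $ for the derivative of $ \\tanh $ is notation introduced here for illustration. Applying the chain rule one step at a time gives:\n",
    "\n",
    "$$ \\frac{\\partial h_t}{\\partial h_{t-1}} = W_{hh} D_t, \\qquad \\frac{\\partial h_T}{\\partial h_0} = \\prod_{t=1}^{T} W_{hh} D_t $$\n",
    "\n",
    "Every diagonal entry of $ D_t $ lies in $ (0, 1] $, so the norm of this product is at most $ \\lVert W_{hh} \\rVert^T $. Whether the gradient actually shrinks depends on $ W_{hh} $ and on how saturated the $ \\tanh $ units are, which is what the experiment below measures."
   ]
  },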
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "rnn = SimpleRNN(50, 125)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rnn_step(x_t, h_t, weight_ih, weight_hh, bias_hh):\n",
    "    return torch.tanh(x_t @ weight_ih + h_t @ weight_hh + bias_hh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "h_0 = torch.zeros(rnn.hidden_size, requires_grad=True, device=test_embeddings.device)\n",
    "h_t = h_0\n",
    "grads = []\n",
    "\n",
    "for t in range(100):\n",
    "    h_t = rnn_step(\n",
    "        test_embeddings[:, t, :], h_t,\n",
    "        rnn.weight_ih, rnn.weight_hh, rnn.bias_hh,\n",
    "    )\n",
    "    loss = h_t.abs().sum() # we'll use the l1 norm of the current hidden state as the loss\n",
    "    loss.backward(retain_graph=True)\n",
    "    grads.append(torch.norm(h_0.grad).item())\n",
    "    h_0.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1a7bf1f400>]"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD8CAYAAAB5Pm/hAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xl4XPV97/H3dxZpJFm7JXm3bDA4xgkBBJgQEiAkgTYJNOUm5NKEpKTuQpa2aUlo+9w2z21uk9vepsmTloQAgbQJWQgNJCULIQuhgME2m40JXvAi27JkW6u1zsz3/jFnZNmWLXkWjTT6vJ7Hj2bOnJn5nTnwmd98z+/8jrk7IiJSvEKFboCIiOSXgl5EpMgp6EVEipyCXkSkyCnoRUSKnIJeRKTIKehFRIqcgl5EpMgp6EVEilxkohXM7G7gHUC7u68es/yjwC1AAvgvd781WH4bcHOw/GPu/pOJ3mPu3Lne3Nyc0QaIiMxWGzZsOOjuDROtN2HQA/cAXwK+nl5gZlcA1wLnuvuQmTUGy1cBNwDnAAuAn5nZWe6eONUbNDc3s379+kk0RURE0sxs12TWm7B04+6PAYePW/zHwGfdfShYpz1Yfi3wLXcfcvdXgW3ARZNutYiI5FymNfqzgMvMbJ2Z/crMLgyWLwT2jFmvNVgmIiIFMpnSzcmeVwesAS4EvmNmy0/nBcxsLbAWYMmSJRk2Q0REJpJpj74VeMBTngaSwFxgL7B4zHqLgmUncPc73L3F3VsaGiY8liAiIhnKNOi/D1wBYGZnASXAQeAh4AYzKzWzZcAK4OlcNFRERDIzmeGV9wGXA3PNrBX4W+Bu4G4z2wQMAzd56gomm83sO8BLQBy4ZaIRNyIikl82Ha4w1dLS4hpeKSJyesxsg7u3TLSezozNgf3dA/x0c1uhmyEiMi4FfQ7c/fir/NF/bGA4nix0U0RETqCgz4Fdh/pJOnT0DRW6KSIiJ1DQ50Br5wAAB3oGC9wSEZETKehzYE9nPwDtCnoRmYYU9FnqHhihdzAOwIEelW5EZPpR0Gdpz+H+0dsq3YjIdKSgz1K6Pg/q0YvI9KSgz1JrUJ9fPreC9l716EVk+lHQZ6m1c4DK0ghnNs5R6UZEpiUFfZb2HO5nUV0586pjKt2IyLSkoM9Sa+cAi2rLaKqK0T0wwuCI5nATkelFQZ8Fd2dPZz+La8tprCwFjh150zs4wref2c10mDhORGYvBX0WDh8Zpn84Mdqjh2NH3jywcS+f/N6LvNzWW6gmiogo6LORHlp5bNAf7dH/5kAq4HcdOjL1jRMRCSjos5Ce+mBxXTlNVSeWbrYGQb/zUP+JTxYRmSKZXhxcOLZHP6c0QkkkRHtvqnTj7rxyoA9IzW4pIlIoE/bozexuM2sPLht4/GOfMDM3s7nBfTOzL5rZNjN7wczOz0ejp4s9h/upKY9SGYtiZjRVlY726Dt6h+geGAFUuhGRwppM6eYe4OrjF5rZYuBtwO4xi68hdUHwFcBa4Pbsmzh9pYdWpjVVxkaDPl2fn1cVU49eRApqwqB398eAw+M89HngVmDs2MFrga97ylNAjZnNz0lLp6H00Mq0pqoY7cGom3TZ5qpVjezrHmAorvH1IlIYGR2MNbNrgb3u/vxxDy0E9oy53xosG+811prZejNb39HRkUkzCsrd2Xtcj75xTOlm64FeasujXLC0FnfYc3jgZC8lIpJXpx30ZlYO/BXwv7J5Y3e/w91b3L2loaEhm5cqiI7eIYbiSRbXHe3Rz6uKcWQ4Qd9QnFcO9HJWUyVL6ysA1elFpHAy6dGfASwDnjezncAiYKOZzQP2AovHrLsoWFZ09owZcZOWHkvf1j3I1gN9nNVUSXMQ9BpiKSKFctpB7+4vunujuze7ezOp8sz57t4GPAR8IBh9swbodvf9uW3y9JCennhsjb4xGEv/QmsXvUNxzmqaQ215
lMpYRD16ESmYyQyvvA94EjjbzFrN7OZTrP4wsAPYBnwV+JOctHIaSo+hXzhOj/7xrQcBWNFUiZmxtL5cPXoRKZgJT5hy9/dN8HjzmNsO3JJ9s6a/7R19NFSWUl5y9CNMB/2vt6WC/qymSgCW1leweW/31DdSRARNgZCxDbs6OW9xzTHL5pRGqCgJ09E7xNw5JdRVlADQXF9Oa+cAI4lkIZoqIrOcgj4D7b2D7DrUz4XNdSc8lu7Vp3vzkOrRx5POvi4NsRSRqaegz8CGnZ0AXNBce8Jj6QOyY4NeI29EpJAU9Bl4ZmcnpZEQqxdUn/BYuke/omnO6LKl9amRObs18kZECkBBn4ENuw7z+sU1lERO/PjGK900VpYSi4bUoxeRglDQn6b+4Tib9vXQMk7ZBuDMhjmURcPHBL2Z0VxfobH0IlIQmo/+ND23u4tE0mkZ50AswO9esIjLVzZQXRY9ZvnS+nK2dyjoRWTqqUd/mtbv6sQMzl8yfo8+HDIaK2MnLF9aX8Huw/0kk7pQuIhMLQX9aXpm52HObqo8occ+kaX15QzHk7SNudSgiMhUUNCfhkTSeXZ310nr86dyZkNqFM7LbT25bpaIyCkp6E/Dy2099A3FaVk6fn3+VF63qIZIyFgfjMEXEZkqCvrTkA7pTHr0ZSVhzllQxfpdCnoRmVoK+tPw3J4umqpKWVhTNvHK47hgaR3P7+liOK45b0Rk6ijoT8OhI8PMry7DzDJ6fktzLUPxJJv3aSZLEZk6CvrT0D0wQtVpjrYZq2VpquSzQeUbEZlCCvrT0DswQlUs83PMGqtiLK4r0wFZEZlSk7nC1N1m1m5mm8Ys+0cze9nMXjCz/zSzmjGP3WZm28zsN2b29nw1vBB6BkdOe/z88VqW1rF+12FS12gREcm/yfTo7wGuPm7ZI8Bqd38d8ApwG4CZrQJuAM4JnvNvZhbOWWsLyN2zLt1Aqk5/sG+YXZrgTESmyIRB7+6PAYePW/ZTd48Hd58CFgW3rwW+5e5D7v4qqWvHXpTD9hbM4EiSkYRTFcu+Rw9omKWITJlc1Oh/H/hRcHshsGfMY63BshmvZ3AEgKqy7OaBW9E4h6pYhA27Dk+8sohIDmQV9Gb210Ac+EYGz11rZuvNbH1HR0c2zZgS3QOpoM+2Rh8KGecvrdUBWRGZMhkHvZl9EHgHcKMfPbK4F1g8ZrVFwbITuPsd7t7i7i0NDQ2ZNmPK9ARBn23pBlLDLLe299HVP5z1a4mITCSjoDezq4FbgXe5+9ijig8BN5hZqZktA1YAT2ffzMI7WrrJPugvCOr0z+7uyvq1REQmMpnhlfcBTwJnm1mrmd0MfAmoBB4xs+fM7MsA7r4Z+A7wEvBj4BZ3T+St9VMoV6UbgNULqwB0hqyITIkJjyy6+/vGWXzXKdb/DPCZbBo1HfUMpAYZZXPCVFplLMrS+nJe2q8pi0Uk/3Rm7CSla/SVOajRA6yaX8XmfQp6Eck/Bf0k9QyOUBYNUxLJzUd2zoIqdh3qpzeo/YuI5IuCfpK6B7Kf/mCsVQtSdfot+3tz9poiIuNR0E9Sz0A865OlxjpnQTWgA7Iikn8K+knqGRzJyRj6tMbKUubOKeEl1elFJM8U9JOU69KNmfEaHZAVkSmgoJ+knsHsZ6483jkLqtna3qtLC4pIXinoJ6lnIJ6TMfRjnbOgipGEs7VdB2RFJH8U9JOQTHpeevTpkTeq04tIPinoJ6FvOI57bqY/GKu5voLykrDq9CKSVwr6ScjlzJVjhUPGynmV6tGLSF4p6CdhdJ6bHI6jTztnQTUv7e8hmdQ1ZEUkPxT0k5CeuTLXNXpIHZDtG4qzp1PXkBWR/FDQT8LoXPQ5Lt3A2KkQVL4RkfxQ0E9CTw7noj/ekrpyAPZ2Deb8tUVEQEE/Kfks3VSXRSmJhGjvVdCLSH4o6CehZzCOGVSW5v5grJnRWFlKe89Qzl9bRAQmdynBu82s3cw2jVlWZ2aPmNnW4G9t
sNzM7Itmts3MXjCz8/PZ+KnSMzDCnNIIoZDl5fUbK0vVoxeRvJlMj/4e4Orjln0KeNTdVwCPBvcBriF1QfAVwFrg9tw0c3zbO/r455/+hsGR/F6WtmcgtzNXHq+pKsYB9ehFJE8mDHp3fww4fNzia4F7g9v3AteNWf51T3kKqDGz+blq7PF2dBzhiz/fxnN7uvL1FkBq1E0+DsSmNVXFaO9Rj15E8iPTGn2Tu+8PbrcBTcHthcCeMeu1BstOYGZrzWy9ma3v6OjIqBEXLasjZPDE9kMZPX+ycn3RkeM1VJbSMxjP+y8TEZmdsj4Y6+4OnPZpne5+h7u3uHtLQ0NDRu9dXRZl9cJqnspz0HdPQekG0AFZEcmLTIP+QLokE/xtD5bvBRaPWW9RsCxvLjmjnmf3dDIwnL/ecL5LN42VpQAc0AFZEcmDTIP+IeCm4PZNwINjln8gGH2zBugeU+LJi0uW1zOScNbvOv4wQua+uW43D794tNk9A7mfonisdI/+gOr0IpIHkxleeR/wJHC2mbWa2c3AZ4G3mtlW4KrgPsDDwA5gG/BV4E/y0uoxLmyuIxIynsxh+eb2X23j84+8AkA8keTIcCKvpZt0j16lGxHJhwmPMLr7+07y0FvGWdeBW7Jt1OmoKI1w7uIantyRm6BPJp227kFGEk577yCRUOq7MJ8HY2vKo5SEQyrdiEheFMWZsZcsr+eF1m76huJZv9bBI0OMJFLHlp/cfiiv89ykmRkNlaV0qEcvInlQHEF/Rj2JpPPMq9nX6du6j/aqn9x+KK8zV47VVFWqHr2I5EVRBP0FS2spCYd4YvvBrF9rXzCL5MKaMp7YfiivE5qN1VgZU41eRPKiKII+Fg1z3pLc1OnbugcAePf5C9l9uH/0Mn/5LN1A0KPXqBsRyYOiCHpIlW827+th4+5OEllclm9/9yAlkRC//brUzA0/2dwG5PdgLEBjVUxnx4pIXhRN0L9t1Tyi4RDv/rcnOP9/P8It39xIR+/pl0L2dw8yvzrG2U2VzJ1TwsbdqXl08l2j1xBLEcmXogn6VQuqePzWK/iX976et61q4pHNB/h/P/3Nab/O/u4B5lfHMDMuOWMuAJGQUV4SznWTj9GYngZBB2RFJMeKJughFZbXnbeQf/wf5/I/L17Cdze0svPgkdN6jVSPvgyAN5xRD6QOxJrlZy76tKaqYBoE9ehFJMeKKujH+pMrziAaNr7w6NZJPyeZdA70pEo3MCboY/mtz0Nq1A2oRy8iuVe0Qd9YGeOmNzTz/ef28sqB3kk952Bf6mSpdNAvqStnYU1Z3odWAtSWR4mGTT16Ecm5og16gD960xlUlERG562ZyP7gZKl06cbM+NQ1K/nwZcvz1sa01LVjY+rRi0jOFXXQ11aU8PtvXMaPNrWxaW/3hOvvD8bQzwt69ADvPHcB7zp3Qd7aOFZjlS4SLiK5V9RBD3DzpcsIh4wfbZp4tuR0j35BTVm+mzWuJvXoRSQPij7oq8ujvHZhNU/tmHgenP3dg5RGQtSW578mP57GqlLV6EUk54o+6AHWLK/nhdYu+odPPbtl+mSpfA+lPJmmqhjdAyM6O1ZEcmpWBP3Fy+sYSTgbd3Wdcr39XQPH1OenWkNwdmwmZ/SKiJxMVkFvZn9mZpvNbJOZ3WdmMTNbZmbrzGybmX3bzEpy1dhMtSytJRwy1r166knP9ncPsqC6MPV5GHORcNXpRSSHMg56M1sIfAxocffVQBi4Afgc8Hl3PxPoBG7ORUOzURmLsnphNU+dYnbLRHCyVCF79KMXCVedXkRyKNvSTQQoM7MIUA7sB64E7g8evxe4Lsv3yIk1y+p4bk8XA8Pj178P9g0RTzrzCzTiBmBBdRlm8MMX9mU1A6eIyFgZB7277wX+CdhNKuC7gQ1Al7unj3q2AguzbWQurFlez0jCeXZ357iPjw6tLGCPvro8yq1vX8nDL7bxN99/kdQleEVEspNN6aYWuBZYBiwAKoCr
T+P5a81svZmt7+joyLQZk9bSXEvIOGn5Zn/XiSdLFcIfX34Gt1xxBvc9vYfP/NcWhb2IZC2b0s1VwKvu3uHuI8ADwKVATVDKAVgE7B3vye5+h7u3uHtLQ0NDFs2YnNE6/UmuK3u0R1+40k3aX7ztbD74hmbufPxVvv/cuB+fiMikZRP0u4E1ZlZuqYHnbwFeAn4BXB+scxPwYHZNzJ01y+t5bnfXuOPU93cPUBoJUVOgk6XGMjP+1ztWURIJ8XLb5CZkExE5mWxq9OtIHXTdCLwYvNYdwCeBPzezbUA9cFcO2pkTa5bXMZxI8uBze9nXNUA8kRx9bF/3IAtqygp2stTxQiGjLBpm8CQHj0VEJiuridbd/W+Bvz1u8Q7gomxeN19amusojYT45PdeBCAcMlbNr+LSM+eyZX8P86oKW58/XnlJmAGdJSsiWcr/FTWmkapYlJ//xeW8cqCX/V2D7OnsZ8POTu789Q7iSefCpXWFbuIxyqJhBkaSE68oInIKsyroARbWlLHwuLHyfUNxNu7qZOX8ygK1anyxaJiBCebnERGZyKwL+vHMKY3wprPyP/LndJWpdCMiOTArJjWbqcqi4ZOeySsiMlkK+mksphq9iOSAgn4aKy8Ja256Ecmagn4aU+lGRHJBQT+N6WCsiOSCgn4ai6lHLyI5oKCfxsqiYYYTyWOmahAROV0K+mmsrCS1ewbjCnoRyZyCfhori4YBVL4Rkawo6KexspLUicsaYiki2VDQT2OjPXoFvYhkQUE/jaVr9CrdiEg2FPTTWCzo0fcr6EUkCwr6aSxdulGNXkSykVXQm1mNmd1vZi+b2RYzu8TM6szsETPbGvytzVVjZ5uyEtXoRSR72fbovwD82N1XAucCW4BPAY+6+wrg0eC+ZKA8mhp1oxq9iGQj46A3s2rgTQQX/3b3YXfvAq4F7g1Wuxe4LttGzlax9MFY9ehFJAvZ9OiXAR3A18zsWTO708wqgCZ33x+s0wY0jfdkM1trZuvNbH1HR0cWzSheqtGLSC5kE/QR4Hzgdnc/DzjCcWUad3fAx3uyu9/h7i3u3tLQMP0u4zcdxHRmrIjkQDZB3wq0uvu64P79pIL/gJnNBwj+tmfXxNkrGg4RDRv96tGLSBYyDnp3bwP2mNnZwaK3AC8BDwE3BctuAh7MqoWznKYqFpFsRbJ8/keBb5hZCbAD+BCpL4/vmNnNwC7gPVm+x6xWFtXlBEUkO1kFvbs/B7SM89BbsnldOapcV5kSkSzpzNhpTqUbEcmWgn6a03VjRSRbCvpprkw9ehHJkoJ+miuLqkcvItlR0E9zMZVuRCRLCvpprjwaZlClGxHJgoJ+mtPBWBHJloJ+mlONXkSypaCf5mLRMIMjSZLJceeGExGZkIJ+mktfZWowrl69iGRGQT/NlWmqYhHJkoJ+mhsNetXpRSRDCvppbrR0o6AXkQwp6Ke5o6WbZIFbIiIzlYJ+mkv36FW6EZFMKeinuZhq9CKSpayD3szCZvasmf0wuL/MzNaZ2TYz+3Zw9SnJ0NHSTbzALRGRmSoXPfqPA1vG3P8c8Hl3PxPoBG7OwXvMWirdiEi2sgp6M1sE/DZwZ3DfgCuB+4NV7gWuy+Y9ZrvyEh2MFZHsZNuj/xfgViCdQvVAl7un6wytwMIs32NWU41eRLKVcdCb2TuAdnffkOHz15rZejNb39HRkWkzil66Rq9x9CKSqWx69JcC7zKzncC3SJVsvgDUmFkkWGcRsHe8J7v7He7e4u4tDQ0NWTSjuEXDRjhkmgJBRDKWcdC7+23uvsjdm4EbgJ+7+43AL4Drg9VuAh7MupWzmJlpqmIRyUo+xtF/EvhzM9tGqmZ/Vx7eY1aJRcP0q0cvIhmKTLzKxNz9l8Avg9s7gIty8bqSUlYSUo1eRDKmM2NngPJoRDV6EcmYgn4GiOm6sSKSBQX9DFAWDSnoRSRjCvoZoCwaVo1eRDKmoJ8Byko06kZEMqegnwFi0bAOxopI
xhT0M4BKNyKSDQX9DFCuUTcikgUF/QyQngLB3QvdFBGZgRT0M0CsJIw7DMU1J72InD4F/QygqYpFJBsK+hkgHfQaYikimVDQzwC6bqyIZENBPwOke/QaSy8imVDQzwDpHr1q9CKSCQX9DFCmC4SLSBYU9DNATKUbEclCxkFvZovN7Bdm9pKZbTazjwfL68zsETPbGvytzV1zZycdjBWRbGTTo48Dn3D3VcAa4BYzWwV8CnjU3VcAjwb3JQs6GCsi2cg46N19v7tvDG73AluAhcC1wL3BavcC12XbyNlONXoRyUZOavRm1gycB6wDmtx9f/BQG9B0kuesNbP1Zra+o6MjF80oWmNLN0eG4vzdQ5u57+ndeXu/eCLJ9za0Ek9oygWRYhDJ9gXMbA7wPeBP3b3HzEYfc3c3s3Fn4nL3O4A7AFpaWjRb1ymURkKYwaa93bzzS4+zo+MIZjCvOsYVZzfm/P1++tIBPvHd56koDXP16vk5f30RmVpZ9ejNLEoq5L/h7g8Eiw+Y2fzg8flAe3ZNFDOjLBrm4Rfb6B2M87UPXshr5lXxsfueZUdHX87f78nthwB4Zmdnzl9bRKZeNqNuDLgL2OLu/zzmoYeAm4LbNwEPZt48SVs2t4LLVszl4Y9dxhUrG/nK+y8gEjLW/vsGegdHMnrNVw8e4dM/2MzIcSWap3akgn79zsNZt1tECi+b0s2lwPuBF83suWDZXwGfBb5jZjcDu4D3ZNdEAXjoI28kHDpaFltcV86/3ng+77/radb8n0cpL41QEg7x1lVN/O07VzG2hHYyn3/kFR56fh9vPquBy4MSUEfvEFvb+6gui7JpXw/9w3HKS7Ku8IlIAWUz6uZxdzd3f527vz7497C7H3L3t7j7Cne/yt3VLcyBsSGf9oYz5vLVD1zA9Rcs4qrXNHJm4xzueWInDz2/75j1egZH2Hqg95hl7b2D/GhT6pj5jze1jS5f92qqN/+hS5tJJJ3ndnflelNEZIrpzNgZ7sqVTXz62tX8w7tfx90fvJBzF9fw6R+8xKG+IQAO9g3x7n97gt/64q/Z1n407L/19B5GEs55S2r4yea20RE2T+04REVJmJsuacYM1u9SnV5kplPQF5FwyPjH619H7+AIf/eDl+jqH+b37lxHa2c/sWiYv/7PTbg7I4kk31y3m8tWzGXtZcvp7B/h6VdTP7ye2nGYC5fVUVtRwtlNlTyjOr3IjKfia5E5q6mSj1yxgs//7BWe3d1Je+8Qd93UQmvnALc98CIPbNxLWUmYtp5B/v661bzhzHpi0RA/2tTGmU1z2Nbex/UXLALgwuY6HtiYGk8fCatPIDJT6f/eIvTHl5/BynmVHOgZ5PYbz+eyFQ28t2UxFyyt5TMPb+Erv9rOwpoyrljZSHlJhMvPauQnm9tGh1WuWV4PQEtzLUeGE7zc1nuqtxORaU5BX4RKIiH+48MX84OPvpG3vCZ1YnIoZHzmd1bTPTDC863d/N6apaMHeK957Tzae4f48q92MKc0wuoFVUCqRw9Hh1kOx5Pc+esd7DncX4CtEpFMKeiL1Nw5paycV3XMspXzqvjDNy2nsjTCey9cPLr8ypWNlIRDbNnfw4XNtaNlmgU1ZSyojvHMrk7iiSQf/9az/P1/beE9X3mSnQePTOn2iEjmFPSzzF++/WyeuO1K6ipKRpdVxqJctmIucLRsk9bSXMf6nYe59Xsv8KNNbXz4jcsYiid57x1P8qrCXmRGUNDPMmZGZSx6wvJ3nrsAgDcGgZ92YXMtB3qGeGDjXv78rWfxN+9YxTf/4GLiCee9X3mS+ze0sq29j2RS0xWJTFcadSMAXPv6BbxmfhVnz6s8ZvmlZ86lJBziQ29s5qNXngmkSkD3rV3DTXc/zV9893kAKmMRPnn1Sn5vzdIpb7uInJq5F74n1tLS4uvXry90M+Qk+obizCk9sU+QSDrbO/p4bncXDzzbytOvHuZrH7qIN5/VcMK6Hb1DbNh1mEW1
5axeWD0VzRYpema2wd1bJlxPQS+5cGQozu/e/gR7uwZ48JZLWd4wh47eIe54bDs/29I+Ws+vLY/yy7+8guqyE8tHInJ6Jhv0qtFLTlSURvjqB1qIhkN8+Ovr+dyPX+ZN//cX3P3fO1k+t4LbrlnJF254PV0DI3zp51sL3VyRWUU1esmZxXXl3H7j+dx45zq+/KvtvPN1C/jTq1awvGHO6DqPbz3IPU/s5MaLl9I8t6KArRWZPVS6kZzbuLuTOaURzmqqPOGxAz2DXPFPv+RNKxr48vsvKEDrRIqHSjdSMOcvqR035AGaqmL80ZvP4Meb20YvcHIy7k5H7xDbO/p4fk8XL7Z2Mx06JiIzjUo3MuX+4LLl3Pf0bj7yzY38znkLeee5C3jtwurRi6UMjiT4wfP7uOeJnWze13PMc1fOq+QPLlvOO89dwJ7Ofv5720FebO1m9cJqrlzZyOK68pO+7/7uAR7YuJeHnttHdVmUa147j6tXz2N+ddlJn9PWPcjG3Z0YcNWqJqKa3E1moLyVbszsauALQBi4090/e7J1VbqZfZ7f08UXH93KY1s7GEk4NeVRqmJRKkojtHUP0Nk/worGOVx/wSLmVceYUxrhYN8Qdz++k98c6KUkEmI4nppDvyoWoWcwDsDyuRXEomH6h+P0DycoiYSoKIkQjRgv7esh6XBRcx09gyOjk7U1VJZSURKmvCRCSSREyCBkxv7uQfZ2DYy2eWFNGX/45uW8p2UxsWh46j80keMUdHilmYWBV4C3Aq3AM8D73P2l8dZX0M9e3f0j/Hjzfl5o7ebIUJy+oQTlJWFuuHAxl5xRf8IlEd2dX73Swc+2HGDV/GouPbOeJXXl7DzUz89fbg9m4HTKSyKURcOMJJMcGUqF/usX13D9BYtYWp86CLy9o4+fbG5jz+F+jgwlODIUZziRxB0cp6a8hPOX1HLB0lo6eoe4/Zfb2Li7CzOoKIlQXhImGg4xFE8yNJIg6c7cylIaK0tprIqxtK6c5voKFtWmfjEMJ5Ik3VlYU07z3HJKIxN/WYwkkhhommgZV6GD/hLg79z97cH92wDc/R/GW19BLzOBu/P0q4f5720HOTKcoH84zlA8SWkkTCwawjAO9g3R3jtIW/cgrZ1O9JclAAAHDklEQVQDxE8yNUQ4ZCytK6exqpT6ilLq55TQVBVjfnWMxsoYvznQy2OvdLDu1UMMx5M0VcVYUFNGbXkJsWiIWDRMSSRENGSEQyGiYaMkEqIkHKKsJMyc0giVsSgVpWHCISNsqcdrykuoryihMhYhnvTUl1Q8QTyRuiDNSMKJJ5PEE04i6VSUhqkqS/3acoeheIKheOrLJxwyIuHULyBITa+R/jUUCm6HQ4aZkUw6w4kkw4kkI/EkSYekO2ZQWRpNfX4TXOfY3Ul66kS9RNJJuBMJGaWRiZ9brCYb9Pmq0S8E9oy53wpcnKf3EpkSZsbFy+u5+LiJ304mnkiyr2uQ1q5+QkHQAuw53M+29j62d/TR0TvElrYeDvUN0z0wcszzlzdUcMOFS6iKRdjbNci+rgH2dQ0wOJJgcCTBcCJJPOkkEkdDdKYeq46EjLJomIQ78aSTTDrpTUkH/KmUREKURkKURsJB8IMHXyZJ9+BX2onGfl4hAzMwjPT3hhF8gYWOXX78azhOMlVJTH2xhlJfdA4kkz76xTb2/dKd7BvXLOWWK86cxKeUuYIdjDWztcBagCVLlhSqGSJ5EwmHWFJfzpL6Yw8Qn7+kdtz1B4YT7O8eoK1nkMW15ac8sDweD0JyYCRB72Cc3sERjgylSkqJpDMcT9LZP8zhI8P0DMSJRozSyNFfBtFwiEg4+BuEVd9QnJ6BEXoG44SD3nP6Cyv9KyD13qmwc4eE+zHhlvRUYKbDOBoOEQqCMJl0+oYS9A6OMDCSIBIyQsEvkHToQup6Ckbq10IkfDRIRxJHf5UMx5MMjqRuE6ybfo4FIQ7ppPbR2+kvBYKw9jFfCUkPti34
wjiZUKqxo59F+svKRn/hHF3n+C+S5vr8n0+Sr6DfCywec39RsGyUu98B3AGp0k2e2iEyY5SVhFneMOeYE8xOh5kRDYK6KhYFTj6aSGaXfB3heQZYYWbLzKwEuAF4KE/vJSIip5CXHr27x83sI8BPSA2vvNvdN+fjvURE5NTyVqN394eBh/P1+iIiMjkanCsiUuQU9CIiRU5BLyJS5BT0IiJFTkEvIlLkpsWFR8ysA9iV4dPnAgdz2JyZYjZu92zcZpid2z0btxlOf7uXunvDRCtNi6DPhpmtn8ykPsVmNm73bNxmmJ3bPRu3GfK33SrdiIgUOQW9iEiRK4agv6PQDSiQ2bjds3GbYXZu92zcZsjTds/4Gr2IiJxaMfToRUTkFGZ00JvZ1Wb2GzPbZmafKnR78sHMFpvZL8zsJTPbbGYfD5bXmdkjZrY1+Dv+1SxmODMLm9mzZvbD4P4yM1sX7PNvB9NgFw0zqzGz+83sZTPbYmaXzIZ9bWZ/Fvz3vcnM7jOzWDHuazO728zazWzTmGXj7l9L+WKw/S+Y2fmZvu+MDfrgAuT/ClwDrALeZ2arCtuqvIgDn3D3VcAa4JZgOz8FPOruK4BHg/vF6OPAljH3Pwd83t3PBDqBmwvSqvz5AvBjd18JnEtq24t6X5vZQuBjQIu7ryY1tfkNFOe+vge4+rhlJ9u/1wArgn9rgdszfdMZG/TARcA2d9/h7sPAt4BrC9ymnHP3/e6+MbjdS+p//IWktvXeYLV7gesK08L8MbNFwG8Ddwb3DbgSuD9Ypai228yqgTcBdwG4+7C7dzEL9jWpKdPLzCwClAP7KcJ97e6PAYePW3yy/Xst8HVPeQqoMbP5mbzvTA768S5AvrBAbZkSZtYMnAesA5rcfX/wUBvQVKBm5dO/ALcCwWWXqQe63D0e3C+2fb4M6AC+FpSr7jSzCop8X7v7XuCfgN2kAr4b2EBx7+uxTrZ/c5ZxMznoZxUzmwN8D/hTd+8Z+5inhk4V1fApM3sH0O7uGwrdlikUAc4Hbnf384AjHFemKdJ9XUuq97oMWABUcGJ5Y1bI1/6dyUE/4QXIi4WZRUmF/Dfc/YFg8YH0z7jgb3uh2pcnlwLvMrOdpMpyV5KqX9cEP++h+PZ5K9Dq7uuC+/eTCv5i39dXAa+6e4e7jwAPkNr/xbyvxzrZ/s1Zxs3koJ8VFyAP6tJ3AVvc/Z/HPPQQcFNw+ybgwaluWz65+23uvsjdm0nt25+7+43AL4Drg9WKarvdvQ3YY2ZnB4veArxEke9rUiWbNWZWHvz3nt7uot3XxznZ/n0I+EAw+mYN0D2mxHN63H3G/gN+C3gF2A78daHbk6dtfCOpn3IvAM8F/36LVL36UWAr8DOgrtBtzeNncDnww+D2cuBpYBvwXaC00O3L8ba+Hlgf7O/vA7WzYV8DnwZeBjYB/w6UFuO+Bu4jdRxihNQvuJtPtn8BIzWycDvwIqlRSRm9r86MFREpcjO5dCMiIpOgoBcRKXIKehGRIqegFxEpcgp6EZEip6AXESlyCnoRkSKnoBcRKXL/H9LNtGk2FVgcAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(grads)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you can see, the gradients decay as time progresses. This is one of the factors that makes simple RNNs more difficult to train compared to LSTMs. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The gradient dynamics of LSTMs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, let's compare the same plot with LSTMs. Though this might not be very well known, the original formulation of the LSTM did not have a forget gate; we'll be using the formulation without the forget gate first and then see how the forget gate changes the dynamics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = NaiveLSTM(50, 125)\n",
    "hidden_size = lstm.hidden_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lstm_step(x_t, h_t, c_t, W_ii, W_hi, b_i, W_if, W_hf, b_f,\n",
    "              W_ig, W_hg, b_g, W_io, W_ho, b_o, use_forget_gate=False):\n",
    "    i_t = torch.sigmoid(x_t @ W_ii + h_t @ W_hi + lstm.b_i)\n",
    "    if use_forget_gate:\n",
    "        f_t = torch.sigmoid(x_t @ W_if + h_t @ W_hf + lstm.b_f)\n",
    "    g_t = torch.tanh(x_t @ W_ig + h_t @ W_hg + lstm.b_g)\n",
    "    o_t = torch.sigmoid(x_t @ W_io + h_t @ W_ho + lstm.b_o)\n",
    "    if use_forget_gate:\n",
    "        c_t = f_t * c_t + i_t * g_t\n",
    "    else:\n",
    "        c_t = c_t + i_t * g_t\n",
    "    h_t = o_t * torch.tanh(c_t)\n",
    "    return h_t, c_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate \n",
    "h_0, c_0 = (torch.zeros(hidden_size, requires_grad=True), \n",
    "            torch.zeros(hidden_size, requires_grad=True))\n",
    "grads = []\n",
    "h_t, c_t = h_0, c_0\n",
    "for t in range(100):\n",
    "    h_t, c_t = lstm_step(\n",
    "        test_embeddings[:, t, :], h_t, c_t,\n",
    "        lstm.W_ii, lstm.W_hi, lstm.b_i,\n",
    "        lstm.W_if, lstm.W_hf, lstm.b_f,\n",
    "        lstm.W_ig, lstm.W_hg, lstm.b_g,\n",
    "        lstm.W_io, lstm.W_ho, lstm.b_o,\n",
    "        use_forget_gate=False,\n",
    "    )\n",
    "    loss = h_t.abs().sum()\n",
    "    loss.backward(retain_graph=True)\n",
    "    grads.append(torch.norm(h_0.grad).item())\n",
    "    h_0.grad.zero_()\n",
    "    lstm.zero_grad()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1a7bf76eb8>]"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAD8CAYAAAB+UHOxAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd8XNWZ8PHfo96tLsuWZBtbbuCCEca0UEwvoa6BFCCw62QDC9k3b1jCJi/ZJYUsCdmQEBIIBMiymOpgQjEGTC+23C1XuciSrN7bjKac94+5Mx5ZXRppJM3z/Xz08cyZO3fO/QzcZ85zmhhjUEopFXrCgl0BpZRSwaEBQCmlQpQGAKWUClEaAJRSKkRpAFBKqRClAUAppUJUvwFARHJFZL2I7BKRIhG52yr/iYiUi8hW6+8yv/f8UESKRWSviFzsV36JVVYsIveOzCUppZQaCOlvHoCIZAPZxpjNIpIIbAKuBlYArcaYXx13/HzgeWApMAV4F5htvbwPuBAoAzYCNxljdgXucpRSSg1URH8HGGMqgArrcYuI7Aam9vGWq4BVxhg7cEhEivEEA4BiY8xBABFZZR2rAUAppYKg3wDgT0SmAycDXwJnAneKyM1AIfB9Y0wDnuDwhd/byjgWMEqPKz+tr89LT08306dPH0wVlVIq5G3atKnWGJPR33EDDgAikgC8AnzPGNMsIo8BDwDG+vfXwG1DrK//56wEVgLk5eVRWFg43FMqpVRIEZGSgRw3oFFAIhKJ5+b/nDHmVQBjTJUxxmWMcQNPcCzNUw7k+r09xyrrrbwLY8zjxpgCY0xBRka/AUwppdQQDWQUkABPAruNMQ/7lWf7HXYNsNN6vAa4UUSiRWQGkA9swNPpmy8iM0QkCrjROlYppVQQDCQFdCbwTWCHiGy1yu4DbhKRxXhSQIeBbwMYY4pE5EU8nbtO4A5jjAtARO4E1gLhwFPGmKIAXotSSqlB6HcYaDAVFBQY7QNQSqnBEZFNxpiC/o7TmcBKKRWiNAAopVSI0gCglFIhalATwZRSSo2817Z6Rsh/ddEUPAMxR4a2AJRSagypb+vk/jVF/O+XR0b8szQAKKXUGPLLt/bQanPywNUnjeivf9AAoJRSY8amkgZeKCzl9rNmMDsrccQ/TwOAUkqNAU6Xmx/9bSfZk2K4a3n+qHymBgCllBoD/vpFCbsrmvnxFfOJjx6d8TkaAJRSagxYvaWcxbnJXHrS5FH7TA0ASik1BrTZnUxNjh3xjl9/GgCUUmoMsDncREeO7i1ZA4BSSo0BdqeLmMjwUf1MDQBKKTUG2BxuYiI0ACilVMjpcLiI0RSQUkqFFofLjcttiNUUkFJKhRabwwWgfQBKKRVqbA43gKaAlFIq1HhbANHaAlBKqdBid2oKSCmlQpIvBRShKSCllAopHdoJrJRSocnbBxAbpQFAKaVCyrEUkAYApZQKKcfmAWgfgFJKhRSdCKaUUiHK5vSkgHQ5aKWUCjF2bQEopVRo6ui0AoB2AiulVGixOV2EhwmR4aO3HSRoAFBKqaDzbAYTNqr7AYMGAKWUCjqbY/S3gwQNAEopFXQ2h3tsBgARyRWR9SKyS0SKRORuqzxVRNaJyH7r3xSrXETkEREpFpHtIrLE71y3WMfvF5FbRu6ylFJq/LA5XaM+BBQG1gJwAt83xswHlgF3iMh84F7gPWNMPvCe9RzgUiDf+lsJPAaegAHcD5wGLAXu9wYNpZQKZXaHa9RHAMEAAoAxpsIYs9l63ALsBqYCVwHPWIc9A1xtPb4KeNZ4fAEki0g2cDGwzhhTb4xpANYBlwT0apRSahwKxobwMMg+ABGZDpwMfAlkGWMqrJcqgSzr8VSg1O9tZVZZb+VKKRXSbA73qK8ECoMIACKSALwCfM8Y0+z/mjHGACYQFRKRlSJSKCKFNTU1gTilUkqNabaxmgICEJFIPDf/54wx
r1rFVVZqB+vfaqu8HMj1e3uOVdZbeRfGmMeNMQXGmIKMjIzBXItSSo1LY3YYqHhmJjwJ7DbGPOz30hrAO5LnFuA1v/KbrdFAy4AmK1W0FrhIRFKszt+LrDKllAppNoc7KKOAIgZwzJnAN4EdIrLVKrsPeBB4UURuB0qAFdZrbwKXAcVAO/AtAGNMvYg8AGy0jvtPY0x9QK5CKaXGMbszOC2AfgOAMeYToLf5yct7ON4Ad/RyrqeApwZTQaWUmug8S0GMwRSQUkqpkWUbD8NAlVJKBZbD5cbpNsSOxU5gpZRSIydY20GCBgCllAoqm8OzHaSmgJRSKsR4WwDR2gJQSqnQYndqCkgppUKSLwUUoSkgpZQKKdoJrJRSIarDCgBjejVQpZRSgXcsBaQBQCmlQsqxFJD2ASilVEjRPgCllApRNqcnBTRWN4VXSik1QuzaAlBKqdDkSwFpJ7BSSoWWDoeL8DAhMry3bVdGjgYApZQKIs9mMGF4dt8dXRoAlFIqiIK1ITxoAFBKqaCyOdwaAJRSKhTZnK6gDAEFDQBKKRVUdocrKCOAQAOAUkoFlScFpC0ApZQKOR3aCayUUqHJ5nARqwFAKaVCjw4DVUqpEGVzuHUUkFJKhSK7U1sASikVkjxLQWgAUEqpkOPpA9AUkFJKhRSHy43TbTQFpJRSoca7F4AOA1VKqRBjc3i2g9QUkFJKhRhvCyB6rLYAROQpEakWkZ1+ZT8RkXIR2Wr9Xeb32g9FpFhE9orIxX7ll1hlxSJyb+AvRSmlxhe7M3j7AcPAWgBPA5f0UP4bY8xi6+9NABGZD9wInGi95w8iEi4i4cCjwKXAfOAm61illApZvhRQRHCSMRH9HWCM+UhEpg/wfFcBq4wxduCQiBQDS63Xio0xBwFEZJV17K5B11gppSYI34bwY7gF0Js7RWS7lSJKscqmAqV+x5RZZb2VK6VUyDrWCTy+AsBjwExgMVAB/DpQFRKRlSJSKCKFNTU1gTqtUkqNOR3jcRioMabKGOMyxriBJziW5ikHcv0OzbHKeivv6dyPG2MKjDEFGRkZQ6meUkqNC8dSQONoGKiIZPs9vQbwjhBaA9woItEiMgPIBzYAG4F8EZkhIlF4OorXDL3aSik1/gW7D6DfTmAReR44F0gXkTLgfuBcEVkMGOAw8G0AY0yRiLyIp3PXCdxhjHFZ57kTWAuEA08ZY4oCfjVKKTWO2JyePoBgLQc9kFFAN/VQ/GQfx/8M+FkP5W8Cbw6qdkopNYHZx/EoIKWUUsPgSwHpctBKKTXxVbfYfDd+m8NNmEBkuASlLhoAlFJqlLjdhst++zEPvrUH8AwDjY0MR0QDgFJKTQhut+HBt/ZwpK69S3lJfTu1rZ28trUch8sd1A3hQQOAUkoF3NGmDv744QFeLCztUl50tAmAhnYHnxTXeraDDGIA6HcUkFJKqcFpsTkB2FbW2KV819FmIsKEuKhwXt96FLvLHbQhoKAtAKVUCHqnqJIXNh4ZsfP7AkBpI8YYX3nR0WZmZSZw6UnZrC2qpKndEbQRQKABQCkVgp7+7DB//PDgiJ2/xeYAoNnm5LBfP8CuimZOnDKJKxdNoa3TxZeH6oK2DARoAFBKhaDKZht1rfYRO7+3BQCeVgB4hn/WtNiZPyWJ02emkZ4QjcMVvA3hQQOAUioEVTXZaLY5cbjcI3J+bwtABLZaAWDX0WYATpySRHiYcMVCz5JqwVoJFDQAKKVCTKvdSVunZyJWQ3vniHxGs9UCWDh1kq8juMgKAPOnJAFw5aIpQPCWgQANAEqpEFPZZPM9bmhzjMhntNicRIWHcer0VIqONuNwudlV0UxuaixJMZEALMlLZnZWAlNTYkekDgOhw0CVUiGlqvlYAKhrswOJAf+MFpuDxJgIFuUm0/nJIfZWtrDraDMnZk/yHSMirLnzLCLDtRNYKaVGhX8AqG8buRRQYkwEi3OTAfi0uJbDdW2+9I9XTGQ44WHBWQYCNAAo
pUJMZbN/CmhkAoCnBRBJTkosKXGRvFBYijGeDuCxRAOAUiqkVDXZiI/ydLzWjVgA8LQARIRFuckcrGkD6NYCCDYNAEqpkFLVbGdKciyTYiNHuAXg6WJdlONJA6XGRzE5KWZEPm+otBNYKRVSKpttZCXF4HSbEW4BeEb7ePsB5mcnBW3Z595oC0ApFVKqrACQGh81YvMAvCkggIU5npE/J04dW+kf0BaAUiqEuN2G6hY7kydF09ThoKyhvf83DZLLbWi1H2sBpCVE89StBSy0UkFjibYAlFIho7bNjsttyEqKIW2EWgCtds8s4KSYY7+vz5+bRXpCdMA/a7g0ACilQkZVk2cBuKykGFLio6hv6+yyXHMgeNcBSowZ+wkWDQBKqZDhnQQ22WoBOFzG94s9ULwrgXpTQGOZBgClVMjwTgLztgAg8LOBjwUAbQEopcapTSUNNNtGZrG0YKlqthEmkJ4QRVqAAsDxKaRjKSBtASilxqGd5U1c99hnPPvZ4WBXJaAqm2xkJEYTER4WkBbA/qoWTrx/LfurWnxl2gJQSo1rD63dC8DB2rYg1ySwqlrsZFmzcb0tgOFMBvukuJb2The7Kpp9ZdoJrJQat748WMeH+2oIEyitD/w4+ZFkjOHxjw6wasMROp3dd/uqarL5AkCqFQCGsxzE9rImAGpajm0v6d0MJmkcpIDGfohSSo0aYwwPrd1LVlI0S2ekseFQXbCrNCh/+OCAr/Xyu/eLufP8WVy3JIeoCM9v3cpmG0tnpAIQFxVOVETYsFJA3t2+/JeYbrE5iQwXoiPG/u/rsV9DpdSoWb+3msKSBv7l/HxmZyZQ1WzH5nAFu1oD8uaOCh5au5erFk/hL986lYzEaH746g7ueXkbADaHi6YOB1lJnglZIkKaNRdgKJptDt8qn1XNx1oA3qWgx9q6Pz3RAKCUAjzLJDy0dh/T0uK44dRc8tLiAEZkuYRA21bayL++sJVTpqXwy+sWct6cTFZ/9wy+deZ01mw7Sml9u+9XepbfipwpcUMPADut9E9EmFDd0rUFMB7y/6ABQCllKW/sYHdFM986YzqR4WHkpnoCwJEx3g9gjOG7z20mIzGaP33zFN8m6yLCyq+cQJgIT316yLcX8ORJxwJAWkIU9UNcDmKbFQBOOyGV6uNaAOMh/w8aAJRSFu9oGO8v/zxvAKgb2wGgrq2T8sYObj9rRrf1drInxXLloim8uLGU/dWtQOBaANvLGslLjWNOVhLVLf4BYAK1AETkKRGpFpGdfmWpIrJORPZb/6ZY5SIij4hIsYhsF5Elfu+5xTp+v4jcMjKXo5QaqrpWz00sNd5zE02LjyIuKpwj9R3BrFa/yho89ctNievx9dvPmkFbp4vHPjgAdA0AqcPoA9he1sTCnElkJkXTanf6lpSYUAEAeBq45Liye4H3jDH5wHvWc4BLgXzrbyXwGHgCBnA/cBqwFLjfGzSUUmODtwXgHR8vIuSlxo35FJC3jyInNbbH10+aOollJ6RS3thBbGR4l1U6U+OjaLE5exwy2peaFjvljR0sykn2dSpXW30M3k7g8aDfAGCM+QioP674KuAZ6/EzwNV+5c8ajy+AZBHJBi4G1hlj6o0xDcA6ugcVpVQQ1bVaASAhyleWmxo35ucCeFsAU5N7DgAA/3jWCQBkJUV3GZ3jnQvQOMh+gO3W8M9FuclkJnpaFN40UPMEawH0JMsYU2E9rgSyrMdTgVK/48qsst7KlVJjRF2rnbiocOKijt28vC2AQC+ZHEhlDe0kx0X2+av7/LmZzMyIZ1pafJfy1CHOBt5W1kSYwElTk3wtgKpmW7fNYMa6YYcpY4wRkYD91yEiK/Gkj8jLywvUaZVS/ahr6/TdEL3yUuPocLiobe0kI3HsbWgCUFrfQU5K77/+AcLChOf/aVm3sflDnQ28vayR/MxE4qIiyLT6FKqb7T1uBjOWDbUFUGWldrD+rbbKy4Fcv+NyrLLeyrsxxjxujCkw
xhRkZGQMsXpKqcGqa+sk7bhRNHnjYChoWUM7Ock9dwD7y0yK6RbEhtICMMb4OoABEqMjiIkMo7rFNq7WAYKhB4A1gHckzy3Aa37lN1ujgZYBTVaqaC1wkYikWJ2/F1llSqkxoq7VTvpxLQDvXICR6AcorW/n0fXFuN1DTyAYYyhr6L8F0BtfC2AQfQBlDR3Ut3WyMNezx6+IkJUUQ1WzfVxtBgMDGwb6PPA5MEdEykTkduBB4EIR2Q9cYD0HeBM4CBQDTwDfBTDG1AMPAButv/+0ypRSY0Rda2eXDmDAd2MdiRbAS4WlPLR2L9vLm4Z8jtrWTuxO95ADQHKs50bt7QD3t/lIAw+v29ctQBWWeG5di/02ec9KjKGq2TauloKGAfQBGGNu6uWl5T0ca4A7ejnPU8BTg6qdUmpUGGOoa7P75gB4xUSGMzkpZkQCwF5rDf33d1exODe5n6N75hsC2sscgP5EhIeRHBfZbS7A7opmbnlyAy12J+fMTueUaam+19btqiIzMZoTpyT5yjKSotl9tHlcbQYDOhNYKYVn6KLDZUg/rgUAjNhcgH1Vnpm57+2p7ufI3vkmgaUOLQAApMZ1XQ6ivLGDW/+ygfjoCKLCw3hje6XvNZvDxQd7a7hwfhZhYcc6lMdrC0ADgFLK9wv4+BQQjMxcAJvDxeG6NpLjIik62uxbp2ewfHMAhpgCAms2sJUCOtrYwS1PbaC908Uzty3lK7MzeGtnhS8N9Km1AcxFJ07uco6spGjaOl2+PYc1ACilxg3vMhBp8d2HeualxlHZbAvostDF1a0YA7eeMR2A94fYCihraCclLpKE6KHfcFPjo9hV0cylv/2YMx58nyN17TxxcwFzJidy+cLJVDTZ2FLaAMDaokoSoyM4/YS0LufItOYCFFvrDelicEqpcaO2h1nAXnlpsRjjSY0Eyt5KT/7/ioXZ5KTEDiMAdAw5/+81OyuR9k4nk2IjuOeSObz1vbNZZt3gL5iXRVSEJw3kdLl5d3c1583N9G0w45VlzQY+UNM6bjaDAd0RTCkF1LX13QIAz0igmRkJAfm8fVUtRIWHMT0tnuVzM3mxsAybw+VbynmgyhramZ2VOKy6fP+i2dy1PL/bTR08nbnnzM7gzR0VXDg/i/q2Ti4+Lv0DXVsA42UzGNAWgFKKY8Mgj58JDCMzF2BvVQszMxOICA/jvLmZdDhcfH6w7+0nm9odPPPZYZraPSNthjsHwEtEerz5e12+IJvKZhsPrd1DVEQY58zpPkHVOxt4PK0EChoAlFJ4OoGTYiJ6vBFmJEQTGxlOSQD3BdhX2cKcLE9rYtkJacRGhvP+7p7TQG634cXCUs7/9Qfcv6aIP33kWda5ptVuzQEYXgqoP8vneVI+m480ctas9B77GxKjI4i1Wi8aAJRS40ptq73bZipeIkJuamzAWgDNNgdHm2zMnuxJ3cREhnNWfjrv76nutuhci83Bij99zj0vb2d6ejxL8pJ5ZXMZTpfbNwJouC2A/njTQAAXzc/q8RgR8aWBEqPHRwcwaABQSuFJAfWU/vEK5FyA/dYEsDl+ufvlczMpb+xge1nXWcEvbyqjsKSBX1y7gJe+fTorvzKTqmY7H++v9QsAI9sCAPjaaXlMTorhwl4CABzrCNYWgFJqXKlrs/c4AsgrJyWOsoaOgCwLvbfSM1TSv/P2soXZxEaG89yXJb4yYwwvbCxlUc4kblqaR1iYcP7cTFLjo3hpU6lvFvBw5gAM1HlzMvnivuXdFsvz52sBjJMhoKABQCmFpw+gr5tbXmocrXYnDVYH7HDsq2ohPiq8ywYuSTGRXH3yVNZsO+rr5N1e1sSeyhZWnHpsIeGoiDCuOXkq63ZVsaOsadhzAAIpU1sASqnxxuU21Ld1dlsJ1F8gl4XeW9lCflZil6UUAL6xLA+bw83Lm8sAWLWxlNjIcL66aEqX41YU5OJwGd4uqhyV9M9AeTeGGS97AYAGAKVCXmN7J27T8xBQ
r7y0wAWAfVUtXfL/XidOmcSSvGSe+6KENruT17cd5bIF2d1SKnMmJ7IoZxLGQG4v+wAHg3ezeU0BKaXGDd9m8H2kgHJTAjMXoLbVTl1bp28E0PG+sWwaB2vb+NHfdtJqd3Lj0twej7u+wFM+lloAmYnePgBtASilxola7zpAfXQCx0aFk5EYPewAsK+y+wggf5ctyCYlLpLVW8o5ISOegmkpPR731UVTyE2NZUlez68Hw8zMBKIjwpiZGZjZ0qNh/IQqpdSI8K4E2ts8AK/clNhhp4D2WAFg9uSeb5IxkeGsKMjlTx8d5IaC3F6XVJgUG8nH95w/rLoEWlZSDDt+cnGfs4rHGg0ASoW4vpaB8JeXGkdhScOwPuuT4lqmJseS0Uewuf2sGdS1dXLDqT2nf8ay8XTzB00BKRXy6lrtiEBKXP8B4GhjBw6Xe0if02p38sn+Wi45aXKfi6VlJsXwq39YRHI/9VHDpwFAqRBX29ZJalwU4WF9r2CZmxqH23g2TRmK9Xuq6XS5e1xNUwWHBgClQlx9D5vB9yTPtyro0ALA2qJK0hOiOKWXjl01+jQAKBXi6trsPe4DcLzcYUwGszlcrN9TzYXzs/ptaajRowFAqRBX19pJ6gBaAFlJMUSFhw0pAHxaXEtbp0vTP2OMBgClQlxtq73PZSC8wsOEnJT+l4W2OVz88NXt3P70Rt+6Pt69dM+YmR6QOqvA0GGgSoWwTqebZpuzz1nA/nL7WRa6rKGdb/91E0VHm4kIE1b86XOe+taprNtVxfnzuu+lq4JLA4BSIayhvffN4HuSlxrH1tLGHl/bcKie7/zPJhxON0/eUkBMZDgrny3k0v/+iGabk0s0/TPmaDhWKoRVNNkA+pyY5S8vNY6mDgdNHV2Xha5ssvFPzxaSHBfJa3eeyfJ5WZw5K51VK08nMjyMmMie99JVwaUtAKVC2N7KZqDr5ix98a6+WVrfzqSpkwDPnr3/96VtdDrd/PnmAk7IOLbMw4KcSbxx19nUttqJi9LbzVijLQClQtieyhZiI8N9Y/z7k5vafVXQv3x2mE+Ka/l/V87vcvP3mjwphpOsYKHGFg0ASk0gnx+o4zt/3YRzgMs17K1sYXZWQrfNWXrjPxfA6XKzo6yJX769hwvmZXHjOFy7J9Rpm0ypCWRtUSVvF1WyrayRU6al9nmsMYY9lS1cMC9zwOdPiokkJS6SB9/ewy/e2gN4VhH95XUL+lzfR41NGgCUmkAO17UB8MHemn4DQE2rnfq2TuZOThrUZ/z06gXsPNpEbGQ4cVHhnDc3c8DDSNXYogFAqQnkUO2xAPD9i+b0eexea23+ub3sztWbyxdmc/nC7KFVUI0p2gegVACs31PNuQ+tp6JpaAulBYLD5aasoYPE6Ah2lDdR02Lv83hvAJgzyACgJo5hBQAROSwiO0Rkq4gUWmWpIrJORPZb/6ZY5SIij4hIsYhsF5ElgbgApYJtR1kTd/zvZg7XtbNpmBumDEdpfTsut+H6ghwAPtpX0+X1RmvSl9eeyhYyEqM1fRPCAtECOM8Ys9gYU2A9vxd4zxiTD7xnPQe4FMi3/lYCjwXgs5UKqrKGdm57ZiPJsZGECeyvah3we1/dXMaRuuFtsejPm/+/fEE26QnRfOAXAF7cWMopP32XXUebfWV7KpsHnf5RE8tIpICuAp6xHj8DXO1X/qzx+AJIFhFNJKpxq9nm4LanN2JzuHj6tqXkpcZRXD2wANDU4eD/vLiNxz48ELD6HKzxBIATMhI4Z3YGH++vweU2NLR18vO3duNyG14sLAXA5Tbsr2rtdXN2FRqGGwAM8I6IbBKRlVZZljGmwnpcCWRZj6cCpX7vLbPKlBqXXthQyr6qVv74jVOYnZXIrMwE9le3DOi9uys8v8Q3ldQHrD6H69pIiokgJS6Sc+dk0NjuYGtpI/+1di8tNieLcpP529Zy7E4Xh+vasDvdmv8PccMdBXSWMaZcRDKBdSKyx/9FY4wR
ETOYE1qBZCVAXl7eMKun1MhZt7uKedlJnDnLs8TxrMxEPtxXg9PlJiK8799W3gCwr6qVxvbOgOx/e7i2nRnp8YgIZ+enEybw6Ppi1u+t5vYzZ3BWfjq3/mUj6/dU47b+r5yXPbghoGpiGVYLwBhTbv1bDawGlgJV3tSO9W+1dXg54D9VMMcqO/6cjxtjCowxBRkZuniUGpsa2jopPFzfZRJVfmYCDpehZAAbpvjn4jcfCUzH8aHaNmakxwOQHBfFyXkpvL+nmoyEaO6+IJ+z8zPISormpcIy9lQ0EyYwK7P70g0qdAw5AIhIvIgkeh8DFwE7gTXALdZhtwCvWY/XADdbo4GWAU1+qSKlxpX1ez2/oi+Yl+Ury8/y3EwH0hG8u7KZU6enEBkubDw8/ABgc7g42tTBdCsAAJxnrb7575fPIzEmkvAw4dolOXywr4ZPimuZnh5PTGT4sD9bjV/DaQFkAZ+IyDZgA/CGMeZt4EHgQhHZD1xgPQd4EzgIFANPAN8dxmerEPD+niqW/fw9Xios7f/gUfbu7ioyE6NZ4LfI2UxrIbTifvoBHC43+ypbOTkvhROnTKLw8PD7AY7Ut2MMvhYAwC1nTOexry/hq4um+MquW5KDy23YfKSReYOcAawmniH3ARhjDgKLeiivA5b3UG6AO4b6eSq0NLZ3cs/LO2jucPCDl7fz+cE6HrjqJOKjgz953e508eHeGr66eGqXRdTioyOYmhzL/n5GAh2saaPT5WZ+dhJut+HZL0qwO11ERwz917h3BvD0tGMBIDEmkksXdB1oNyszgZPzktlypFE7gJXOBFZj03+8vovG9k5e+ecz+N4F+azeUs6Vv/+E+rbObsduK23E7nSNWt2+PFhPW6erx0XU8rMS+k0BeTuA52UnUTA9lU6nm53lTcOqky8A+LUAenP9KZ6JYhoAlAYANWp+//5+rv3Dp2zrZUtBr3W7qli9pZw7zpvFgpxJfO+C2Tx721IO1rT5xrF77Sxv4qpHP+WVTd3GE4yYd3dXERMZ5hv94y8/M4EDNa243L0PfttV0UxURBgnZMRzyrQUAAqH2Q9wuLaNtPgoJsVG9nvs9afk8MDVJ3HenIGvAqomJg0AalRUNHXwyPvFbC1t5Jr1oqjnAAASVUlEQVQ/fMrP39xNR2f3X+2N7Z3ct3oH87KTuOO8Wb7ys/MzWJKXzKuby/BkEz28AaHo6PB+QQ+UMYZ3d1Vxdn5Gjx2o+ZmJ2J1uyhp6Hwm0u6KZ2VkJRIaHkZEYzYz0+GF3BB+qbRvQr3+A6Ihwvrlsmm7QrjQAqNHx+/eLMcbwxl1nc+PSPB7/6CBXPfoJLbaue8v+x+u7aGjr5KHrF3a7QV2zJId9Va0UWUMobQ4Xf9vi+eU/mCUYetPR6aKp3dHnMbsqmjnaZONCv9E//mb1MxLIGMOuo81dOmBPmZbCppL6LoFtsA7XtXXJ/ys1EBoA1IgrrW/nxcJSbjg1l3nZSfz8mgX85dZTKa5u5b7VO303vrVFlb7UT09bCF65MJvIcOHVzZ6b/ju7qmi2OZmZEc/eqpZh3UAB7lq1hev++Fmf53l5UxnhYcJ5c3tOn3jH1ffWEVzTYqeurZP5U44FgIJpKTS0OzhgLeUwWG12J1XNdk7I0ACgBkcDgBpxj7y3HxHhzvPyfWXnzc3k/1w4m9e3HeWFjaXUt3Xy76t3cOKUJO48f1aP50mOi2L53CzWbCvH6XLzUmEpU5Nj+eayaTR1OKjuZ/njvhyubWPdriqKq4+1MI5X0dTBc18e4dqTp5KR2PMKmkkxkWQlRXdZE8i/g3qXXwewV8F0z8Ytnx+o7XKuvZUt3L1qC49/1Pd6Qd5F4LQFoAZLA4AaUQdrWnl1SznfOG0akyfFdHntn8+dxVmz0vnJ60Xc8dxmmjoc/HrFIiL7WEbhmiVTqW3tZNXGUj4pruX6U3KYY6VTvOvbA7jdhot+82G/N0+vZz8vISJMCA8T3tjR8/zE
R9d70lh3Lc/v8XWv/MxE31yANduOsuAn7/DwO3uBngPAzIx4clNj+fFrRVz96Kf8+eOD3L1qC5f89iNe23qUR9cf6HOP38O1nv6G6ekD29hdKS8NAGpE/ebd/USFh/HP587s9lp4mPDwDYtIiI7k84N1fO+C2f1uT3jenEyS4yJ54O+7MMYzomW2lXffV3UsAOyvbmVfVStPfHyITmffG6S32Z28VFjKZQuyOWNmGm9sr+iWBiqtb+eFjaWsKMj1bYzeG8+icK28sPEId6/aQmJ0BI+8X8zv3tvP7ooWpibHdhmtIyK88p0z+OGlc+l0uvnpG7tZW1TJt78yk19cu4CmDkev+wwcrm1jzTZPSkxbAGqwgj+rRk1YW0sbeX3bUe44b2avKZPMxBieuPkU1hZV8e2vnNDvOaMiwrhy4RT++kUJZ85K892M0xOiuwSADYfqAE/O/Z1dlVyxcEqP5wNYvaWcFruTW86YRnF1K//2yg6KjjZ36Yf43ftWGquX9JS//KwE2jtd/NsrO/jK7Awe+/oSfvy3nfx63T6iI8I4O7/7GleZSTF8+5yZfPucmRyubSMpNpLU+Cha7U7uf62Id3dXcdoJab7jt5Y28vM3drPhcD1hAl87LW9MTJJT44u2ANSIMMbw07/vIj0hin8+t++b5sl5Kdx76dx+V9D0WlGQiwh8bek0X9nsrAT2+o28+fJQPZOTYshNjeWvn5f0Wc9nPz/MSVOTWJKXwkXzJ3dLAx2qbeOVzeV8/bQ8sifF9ls/byvmovlZPHHzKcRHR/Bf1y/kioXZ2J1u5mf3PQFreno8qfGe1UEToiNYNjON93ZX+143xvCDl7ZxqK6Ney6Zw2f3Lufn1yzot15KHU9/MqgR8fbOSgpLGvjZNSeREOBfpgtyJvHFD5eTlXSsT2F2ViIvFpbidhtEYMOhek6fmca87CQefGsP+6pamJ2ViDGGP398iJL6Nk61Ol/3VbXy0PULERFS4qM4Y2Yab+6o4J6L59DW6eLuVVuIjug5jdWTJXnJvPSd0zk5N9kX1CLCw/jNDYtZnJvMZQsGtw/ShfMy+fFrRRyoaWVmRgIf7qthf3UrD69YxLVLcgZ1LqX8aQtABVyn082Db+8hPzOBGwpy+3/DEPjf/MGzrEF7p4vyxg5K6tqpbrGzdEYqKwpyiYoI43++KMEYw3/+fRc/e3M3L28q4+5VW7l71VZS4iK50m/BtCsWZlNS187mI42sfLaQoqPN/O6mk8lMjDm+Gj0SEU6dntqtRRMZHsY/nn0CU5L7b0X4O9+ac/De7ioAnvzkEJmJ0X2mtZQaCG0BqCE7UtdObFR4l/y+22147IMDlNS185dvnTrgtM5wzba2NtxX1UJtq2c46Gkz0kiNj+LyBdm+uQPPfl7CbWfO4L7L5rKnsoXCw/XkZyV2mdV70fzJ3Ld6J7c9vZGmDgcPr1jE8l4mfo2GqcmxzM9O4t1d1XxldgYf76/lBxfP0Zm8atg0AKghae90cvnvPsbmcHHFwincfPo0yho6eHR9MXsqW1g+N5NzZ4/ehj7etfj3VrVQXN1KWnwUM62JUd9YNo3VW8p59vMSbj1jOj++Yh4iwklTJ/U44cybBvp4fy0/unzemEizXDAvk9+vL+ZXa/cRGxnO10/T3fLU8GkAUEPy9s5KWmxOLlswmXesGbzgGdP+8IpFfHXRFESkn7METlJMJFMmxbCvsoXCkgaWzkj1ff6SvGSuXDSFnJRY7rl4zoDq9Z9XncTuiuZB5+tHygXzs3jk/WLe3V3FzadPC8gWkkppABgFTpebw3VtlNZ3cOqM1IB3igbDy5vKyEuN49GvLaHV7uT1bRWkxkdx4fwswsNG78bvb/bkRD49UEdNi53bz5rhKxcRfnfTyYM614z0+C6bqwTbSVMmkZkYTU2rnW+dOaP/Nyg1AOP/TjRGHaxp5Y3tFby7p5o9Fc3YrclISTERfPP0adx6xoxex8Yfz+ly8+G+
GnJS4sbEGu7ljR18frCOu5fnIyIkxkTytTGQkpiTlcgHe2sAWDojNci1CaywMOGO82ZR3WIbU4FJjW8aAALEGMOeyhbeKapibVGlb8r/krxkvrlsGvOyk0hLiGLVhlL+8MEBnvj4EFcsyGbFqbmc5peuMMbQ1umixeagvq2TtTsrWbWxlOoWO4nREfzPP57GotzkYF4qqzeXYYxne8GxJN/qCE6Mieh3RvF4dMsZ04NdBTXBaAAYpmabg//5ooTnNxyhtL4DEViSl8KPr5jP5Quyu61/c+6cTA7UtPLUJ4dYs/Uor24pJy81jqTYCGpbOqltteP020xEBM6dncG/nzyVX7+zj28++SX/+0/Leuy8HA3GGF7ZXM5pM1L7XRJhtM2xAsCp01ODloZSajzRADBEje2d/PHDgzz3RQktdidnzkrju+fOYvm8zH7Hi8/MSOBn1yzgR5fP580dFby+/SgA8yYnkZ4YTXJsJEmxkSTGRLA4N5mcFM+N9pRpKdzwpy/4xpNf8psVi5meHk9aQhR1rZ18sr+Gj/fX4jZw1/JZLMwZmVbC5iONHKptG/CkqNE0KzOBSbGRnN/LUs1Kqa5kuGuoj6SCggJTWFgY7Gp0U3i4nn95fgtVzTYuW5DNd86ZOWq/yEvq2rjhT19Q2Wzr9trU5Fg6HC7q2zq5ctEUVhTksLeyha2ljdS02DljZjrL52Vy4pSkIY/QuW/1DlZvLmfjjy4Yk53Z7Z1OYiPDR3UEklJjjYhsMsYU9HucBoCBc7sNf/zoAL9+Zx85KbH8/qYlLMgZ/VRMY3sn28uaqGuzU9faSWxUOGfOTGdaWhytdiePf3SQJz4+iM3h6XiemhxLanwUO482YYzn+b9dOpcrF2b3eKM0xrCvqpWZGfFdJnKV1LVx+SOfcNH8LB6+YfGoXa9SanA0AATYhkP1/OKt3Ww50sjlC7N58NoFJMb0vwF3sFQ32yg62syJU5LItJZNqG2188HeGp7+7BA7y5s5Oz+dB646qctesm12Jz/+205e3VLO+XMz+f3XTiYuKoL6tk6ue+wzGts7Wf3dMwe8/6xSavRpAAiQAzWt/OLNPby7u4qspGh+cPFcrlsydVynGFxuw18/P8yv3tmH3eninNmZXLEwm7y0OH7w0jYO1rZx2YJs3tpRwYKcZB77+hL+5fkt7Chv4vl/Oo1Tpk2sIZZKTTQaAIap0+nmjx8e4PfvFxMdEcZ3zp3JbWfOIDYqvP83jxNVzTYe/+ggb2yv8PUppCdE88iNizljVjpriyq56/ktGAMOt5s/fG0Jl46RmbFKqd5pABiGraWN3PPyNvZVtXLFwmzuv/LEAU/aGo/cbsPmIw1sLW3kq4undBnFtKmknu+9sJXbz5zBrToDValxQQPAELjchj9+eICH1+0jMzGan159UlBXgVRKqaEYaAAYe+P4gqSq2ca/vrCVzw7UccXCbH5+7QKSxnAnr1JKDVfIBwBjDC8WlvKzN3bjcBn+67qF/ENBzrju5FVKqYEI6QBwsKaV+1bv4IuD9SydkcqD1y7ghIyEYFdLKaVGRUgGgEO1bTy6vpjVW8qJjwrnwWsXsKIglzBdP0YpFUJCJgAYY9ha2sjTnx3m9W1HiYoI49YzpvOdc2ZO6BE+SinVmwkdAOxOF/urWtla2siLhaVsL2siITqC28+awcqv6I1fKRXaRj0AiMglwG+BcODPxpgHA/0ZlU02bv3LBoqrW31LK+dnJvDAVSdyzZKcMbmImVJKjbZRvROKSDjwKHAhUAZsFJE1xphdgfyctIQopibHcv7cTOZPSWJ+dhIz0uN1ZI9SSvkZ7Z/CS4FiY8xBABFZBVwFBDQARIaH8eStpwbylEopNeGE9X9IQE0FSv2el1llSimlRtloB4B+ichKESkUkcKamppgV0cppSas0Q4A5UCu3/Mcq8zHGPO4MabAGFOQkZExqpVTSqlQMtoBYCOQLyIzRCQKuBFYM8p1UEopxSh3
AhtjnCJyJ7AWzzDQp4wxRaNZB6WUUh6jPiDeGPMm8OZof65SSqmuxlwnsFJKqdGhAUAppULUmN4RTERqgJJhnCIdqA1QdcaLULxmCM3rDsVrhtC87sFe8zRjTL/DKMd0ABguESkcyLZoE0koXjOE5nWH4jVDaF73SF2zpoCUUipEaQBQSqkQNdEDwOPBrkAQhOI1Q2hedyheM4TmdY/INU/oPgCllFK9m+gtAKWUUr2YkAFARC4Rkb0iUiwi9wa7PiNFRHJFZL2I7BKRIhG52ypPFZF1IrLf+jcl2HUNNBEJF5EtIvJ36/kMEfnS+s5fsNaamlBEJFlEXhaRPSKyW0ROn+jftYj8q/Xf9k4ReV5EYibidy0iT4lItYjs9Cvr8bsVj0es698uIkuG+rkTLgD47Tp2KTAfuElE5ge3ViPGCXzfGDMfWAbcYV3rvcB7xph84D3r+URzN7Db7/kvgd8YY2YBDcDtQanVyPot8LYxZi6wCM/1T9jvWkSmAncBBcaYk/CsH3YjE/O7fhq45Liy3r7bS4F8628l8NhQP3TCBQD8dh0zxnQC3l3HJhxjTIUxZrP1uAXPDWEqnut9xjrsGeDq4NRwZIhIDnA58GfruQDnAy9bh0zEa54EfAV4EsAY02mMaWSCf9d41iuLFZEIIA6oYAJ+18aYj4D644p7+26vAp41Hl8AySKSPZTPnYgBICR3HROR6cDJwJdAljGmwnqpEsgKUrVGyn8D9wBu63ka0GiMcVrPJ+J3PgOoAf5ipb7+LCLxTODv2hhTDvwKOILnxt8EbGLif9devX23AbvHTcQAEHJEJAF4BfieMabZ/zXjGeY1YYZ6icgVQLUxZlOw6zLKIoAlwGPGmJOBNo5L90zA7zoFz6/dGcAUIJ7uaZKQMFLf7UQMAP3uOjaRiEgknpv/c8aYV63iKm+T0Pq3Olj1GwFnAl8VkcN40nvn48mNJ1tpApiY33kZUGaM+dJ6/jKegDCRv+sLgEPGmBpjjAN4Fc/3P9G/a6/evtuA3eMmYgAImV3HrNz3k8BuY8zDfi+tAW6xHt8CvDbadRspxpgfGmNyjDHT8Xy37xtjvg6sB663DptQ1wxgjKkESkVkjlW0HNjFBP6u8aR+lolInPXfuveaJ/R37ae373YNcLM1GmgZ0OSXKhocY8yE+wMuA/YBB4B/D3Z9RvA6z8LTLNwObLX+LsOTE38P2A+8C6QGu64jdP3nAn+3Hp8AbACKgZeA6GDXbwSudzFQaH3ffwNSJvp3DfwHsAfYCfwViJ6I3zXwPJ5+Dgee1t7tvX23gOAZ6XgA2IFnlNSQPldnAiulVIiaiCkgpZRSA6ABQCmlQpQGAKWUClEaAJRSKkRpAFBKqRClAUAppUKUBgCllApRGgCUUipE/X9TKKB3m8aCcwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(grads)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice how the gradient keeps on accumulating. The reason the gradient behaves this way is because of the update rule\n",
    "$$ c_t = c_{t-1} + i_t * g_t $$\n",
    "\n",
    "If you're familiar with gradient calculus, you'll see that the gradients for $ c_t $ propagate straight back to the gradients for $ c_{t-1} $. Therefore, the gradient of the initial timestep keeps increasing: since $ c_0 $ influences $ c_1 $, which in turn influences $ c_2 $, and so on, the influence of the initial state never disappears.\n",
    "\n",
    "Of course, this can be a mixed blessing: sometimes we don't want the current timestep to influence the hidden state 200 steps into the future. Sometimes, we want to \"forget\" the information we learned earlier and overwrite it with what we have newly learned. This is where the forget gate comes into play."
   ]
  },
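  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the argument about the update rule concrete, here is a rough sketch of the derivative along the direct cell-state path. It assumes we hold $ h_{t-1} $ fixed (so $ i_t $ and $ g_t $ are treated as constants with respect to $ c_{t-1} $) and writes $ I $ for the identity matrix. The full gradient has additional terms that flow through the hidden state, which this sketch ignores.\n",
    "$$ \\frac{\\partial c_t}{\\partial c_{t-1}} = I \\quad \\Rightarrow \\quad \\frac{\\partial c_T}{\\partial c_0} = \\prod_{t=1}^{T} \\frac{\\partial c_t}{\\partial c_{t-1}} = I $$\n",
    "\n",
    "Under that simplification, the contribution of this path does not shrink no matter how large $ T $ is. The plot above tracks $ h_0 $ rather than $ c_0 $, but $ h_0 $ enters $ c_1 $ through the gates and is carried forward along the same additive path from there."
   ]
  },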
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Turning the forget gate on"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The forget gate was originally proposed in the paper [Learning to Forget: Continual Prediction with LSTM](https://www.semanticscholar.org/paper/Learning-to-Forget%3A-Continual-Prediction-with-LSTM-Gers-Schmidhuber/11540131eae85b2e11d53df7f1360eeb6476e7f4). Let's see how the gradients change when we turn the forget gate on. Adhering to best practices, we'll initialize the bias for the forget gate to 1."
   ]
  },
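  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A rough sketch of why the forget gate changes the picture: consider the gated update $ c_t = f_t * c_{t-1} + i_t * g_t $ and differentiate along the direct cell-state path only. This assumes $ h_{t-1} $ is held fixed (so the gates are treated as constants with respect to $ c_{t-1} $) and ignores the terms that flow through the hidden state; $ \\mathrm{diag}(f_t) $ denotes the diagonal matrix built from $ f_t $.\n",
    "$$ \\frac{\\partial c_t}{\\partial c_{t-1}} = \\mathrm{diag}(f_t) \\quad \\Rightarrow \\quad \\frac{\\partial c_T}{\\partial c_0} = \\prod_{t=1}^{T} \\mathrm{diag}(f_t) $$\n",
    "\n",
    "Each entry of $ f_t $ lies in $ (0, 1) $, so along this path the influence of the initial state is scaled down at every step unless the gate stays close to 1. A bias of 1 nudges the gate toward staying open: if the other terms inside the sigmoid happen to be near zero, $ f_t \\approx \\sigma(1) \\approx 0.73 $, compared with $ \\sigma(0) = 0.5 $ for a zero bias."
   ]
  },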
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm.b_f.data = torch.ones_like(lstm.b_f.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "# generate \n",
    "h_0, c_0 = (torch.zeros(hidden_size, requires_grad=True), \n",
    "            torch.zeros(hidden_size, requires_grad=True))\n",
    "grads = []\n",
    "h_t, c_t = h_0, c_0\n",
    "for t in range(100):\n",
    "    h_t, c_t = lstm_step(\n",
    "        test_embeddings[:, t, :], h_t, c_t,\n",
    "        lstm.W_ii, lstm.W_hi, lstm.b_i,\n",
    "        lstm.W_if, lstm.W_hf, lstm.b_f,\n",
    "        lstm.W_ig, lstm.W_hg, lstm.b_g,\n",
    "        lstm.W_io, lstm.W_ho, lstm.b_o,\n",
    "        use_forget_gate=True,\n",
    "    )\n",
    "    loss = h_t.abs().sum()\n",
    "    loss.backward(retain_graph=True)\n",
    "    grads.append(torch.norm(h_0.grad).item())\n",
    "    h_0.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1a7c134240>]"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJzt3Xd4nNWZ9/HvPRr13qxuSe5F7r3gBl5M7wmEUAIsKSQhhM2m7Sawm82bEAgkGwJLSwgh9ADGJAbb2IDBTbYluUnusiSrWbKKJavOef+YkSxjyZKlGY1m5v5cly5rnpnR3OPH/s3ReU4RYwxKKaU8n8XdBSillHIODXSllPISGuhKKeUlNNCVUspLaKArpZSX0EBXSikvoYGulFJeQgNdKaW8hAa6Ukp5CetgvlhcXJzJyMgYzJdUSimPt3379hPGmPjeHjeogZ6RkUF2dvZgvqRSSnk8ESnsy+O0y0UppbyEBrpSSnkJDXSllPISGuhKKeUlNNCVUspLaKArpZSX0EBXSikv4VOBbrMZXtl6jJrGFneXopRSTudTgb75cBU//vsuHn5vr7tLUUopp/OpQF+9pwyAt3eWsPVItZurUUop5/KZQLfZDB/sKWPxmHhSooL52bu7aWu3ubsspZRyGp8J9JziGsrrmrl2WjL/ccV48svq+dvWY+4uSymlnKbXQBeRIBHZKiK5IrJHRB52HP+ziBwRkRzH11TXl9t/H+wuw2oRlo1LYEVWIgtHxfHoBwVUnWp2d2lKKeUUfWmhNwPLjDFTgKnAChGZ67jvB8aYqY6vHJdVeYEOV57iyv/9lH2ldQAYY1i9p4z5o+KIDPZHRHjo6gmcam7jxU3nLmL269X5rN1bPthlK6XUgPQa6MbulOOmv+PLuLSqAdpQUMnukjq++dft1DW1kl9WT2FVIysmJnY+ZtSwcOZkxvJ+3nGMOfN2jpxo4KkNh3hyw0F3lK6UUv3Wpz50EfETkRygAlhjjNniuOt/RCRPRB4XkUCXVXmB8oprCAu0UnTyNP/+Rh6rd5chAssnJJz1uCsmJ3GosoGC8vrOY+/sLAEgp6iGivqmQa1bKaUGok+BboxpN8ZMBVKB2SKSBfwYGAfMAmKAH3b3XBG5V0SyRSS7srLSSWWfX15JLfNGxvKjFeNYvaeMpz8+xKz0GOLDz/7MWZGViEXg/bxSwN41805OCanRwRgD6/MrBqVepZRyhgsa5WKMqQHWAyuMMaWO7phm4E/A7B6e84wxZqYxZmZ8fK87KA1YfVMrhysbmJwSyT0XZbJiYiLNbTYuzUo857FxYYHMGxnL+3mlGGPYWVRDYVUj3714NClRwazZq4GulPIcfRnlEi8iUY7vg4HlQL6IJDmOCXAtsNuVhfbV7hL7hdBJqZGICL+5aTLfXTaKG2ekdvv4KyYlc/hEA/tK63lnZwmBVguXZSWyfEICGw9WcrqlfTDLV0qpfutLCz0JWC8iecA27H3oq4CXRWQXsAuIA37hujL7bldJDQCTU6MACA/y5/v/MpbIYP9uH3/pxAT8LMK7OSWsyivlkgkJhAf5c8n4BJpabWw8eGLQaldKqYHodZNoY0weMK2b48tcUtEA5RbXkhodTExoQJ8eHxsWyPyRsfzps6O0tNu4bmoKALMzYwgPtLJ2b/k5F1OVUmoo8rqZoruKa5mcGnlBz7liUhIt7TaiQ/xZNMbezx9gtbB4bDzr8stptw3pUZpKKQV4WaDXNLZwrLqRSSlRF/S8SycmEuBn4aopyQRYz/yVLJ+QwIlTLeQU1Ti7VKWUcrpeu1w8ya6SWoALbqFHhwbw3ncWkhIdfNbxJWOGYbUIa/aWMyM92ml1KqWUK3hVoOcV2wM9K+XCAh1gbGL4OcciQ/yZnh7NZ3phVCnlAbyqy2VXcS2ZcaE9jmjpj7mZMew5Xkt9U6vTfqZSSrmCVwV6XnENk/rR
Oj+f2Zmx2AxkF5506s9VSiln85pAr6xv5nht0wX3n/dmenoUVouw5bDv7nBUVN3o7hKUUn3gFYFujOGN7UXAmQlFzhISYGVyaiRbjlQ59ed6iq1HqrnokfXsOV7r7lKUUr3w+EAvq23irj9v45HVBSwcFcfUNOcGOsCcEbHsKq6lsaXN6T97qOsI8vzS+l4e6VrGGFradMtApc7HowN9z/Fa/uXxj9l8uJqHr57IX+6afdY4cmeZnRlDm82wo9D3xqMfPdEAQGFVg1vreGtHCXN+uVbX1lHqPDw60Nftq6CuqY1/3H8Rd8zPwGIRl7zOzPRoLIJPdrscqbL3nxe6uR89p+gkJxtbyS+rc2sdSg1lHh3o5XVNRIf4kxkX6tLXCQ/yJysl0icvjHa0zAur3BvoHa+fX+berh+lhjKPDvSK+maGhQcNymvNyYwhp6iGplbf+ZW/td1G8cnTABzrQwvdGMPBinre2l5MRZ1zd3s66vhg6dgnVil1Lo+eKVpR38ywiMHZ+W52ZizPfnqEnKIa5o6IHZTXdLei6kbabYbRw8I4UHGKuqZWIoLOnbTV0mbjJ2/vYt2+ck422idg3TA9lce+NMUpdbS02ShxfLBooCvVM49uoVfWNZ2zrZyrzM6IQQSf6nbpaBUvdqxAeayHbpdnPz3Mm9uLWTJ2GI/cOJkVExP5cG8ZzW3O+W2m+GQjNgORwf7kl9aftam3UuoMjw10YwyVpwavyyUyxJ9xiRFsPeo7F0aPnrAH+JKxw4Du+9GLqhv5348OsGJiIo9/eSpfmpnGzbPTqG9q45P9zlkDp+N1l09IoL65rbMbSCl1No8N9JONrbS2G4YNUgsdYEZ6FLlFtT6zPvrRqgbCg6xMHW4f219YffbQRWMMD63cg0WEn101ofP4glFxRIf4syrv+FmPf/Hzo7ybU9KvOgAuc+wLu1e7XZTqlkcE+p8/O8K3Xt5+1rGKevtFt8HqQweYlhbNqeY2DlacGrTXdKcjJxrIjAslLNBKbGjAOV0uH+4tZ11+BQ9cMobkqDNLD/v7WViRlcjaveWdF5F3HjvJz1fu4f5Xc3h7Z/EF1VFY1Uh4oJV5I2MR0X50pXriEYFeXt/Mmr1n7xxUUdcMMGhdLkBnSzWnyDcW6jpa1UB6rH1IaHpsSGdLGaCptZ2HV+5hXGI4dy7IOOe5V05OpqGlnfX5Fdhshofe20t8eCBzR8Twb2/ksXZv+YXVERdCSICVjNhQt89aVWqo8ohAHx4TQmu7oazLULiK+o5AH7wWemasfWnence8f8Zox8iSzNgQANJjQ89qoW86VMXx2ib+fcVY/P3O/Wc0JzOGuLAAVuWV8vbOEnKLavjRinE8d8csspIj+NbfdrD5cN+uRxRWNXZ+sIxPCmefTi5SqlseEehp0fZQ6brqnzu6XCwWYWpalMcF+rOfHGbZoxu4+g8b+cqzm/nJ27s4XHn+bqMix8iSDMekreExIZTWNXWOXNlQUEGwvx/zR8Z1+3yrn4XLspJYl1/Or1bnMyUtiuumpRAWaOVPX5tNalQwP3l7V68jVtrabRRVN5Lh+GAZnxhBYVUjp5p9b10dpXrjEYE+PMb+n7nr5JaKumbCA62EBAzuUPppw6PYX1HvMYFS3dDCb9fsx88ixIQG0Nxm4+87irnktx/z3Vd2snp3GY9+UMCXnt7EZb/7lDrHRh4da7h0BHp6bAjGQFG1fYTJhv2VzBsZS5C/X4+vfeXkJJpabVTWN/PQVRM6l2aICQ3gG4tHcriygR29fDger2mizWa6tNAjACjQVrpS5+g1DUUkCPgECHQ8/k1jzM9FJBN4FYgFtgO3GWNaXFFkUlQQfhY5q4VeWd9M/CC2zjtMTYvCGMgrqmH+qO5bp0PJnz47wunWdv5463RGJ9i32TtxqplnPz3MS5sKWZl7HD+LMC4xnH2ldbyRXczdCzM54gj0zM4+dPufx6ob8LMIhVWN3L0w87yvPSsjhsy4
UOZkxjBt+Nl7sl4+OYmfr9zDm9uLz7tfa0e/fUZHoCfbA31vaT0z0mMu9K9DKa/Wl+ZtM7DMGHNKRPyBjSLyT+D7wOPGmFdF5GngbuApVxTp72chKTLonC6Xwew/79CxPO9ODwj0uqZW/vz5UVZMTOwMc4C4sEB+fNl4vr5oJAfK68lKiSQ00MoNT33Oi58f5c75GRytaiAiyEpUiH1maLqjy+PoicbOceFLxgw77+tbLMIH31uEtZtF08ICrVw2KZFVucf52ZUTCA7ovqVf2Bno9tdPjgwiIsiqI12U6kavXS7GrqPD1d/xZYBlwJuO4y8C17qkQofhMSFnd7kM4jouXUWFBDAiPtQj+tFf2lRIfVMb9y0d1e39MaEBzBkRS2ig/XP9rgWZHKtu5KP8CgqrGsmMC0XEHsaxoQGEBvhxrLqRDQWVjIgLZbgjZM8nwGrpcRXMG2ekUt/cxod7y3p8/tGqRoL9/TpnBIsI45IiNNCV6kaf+tBFxE9EcoAKYA1wCKgxxnR0JBcDKT08914RyRaR7MrKyn4Xag90e/+tMYaKuma3tNDBPh49p+jkkJ6C3tjSxvMbj7B4TDyT+rgt36UTE0iODOKFjUc4cqKhs/8c7EE6PDaUgrJ6Nh+uYvHY+AHXODczltToYN7I7nlcemFVA+mxIZ0fLAATkiIoKKvH5iMTvJTqqz4FujGm3RgzFUgFZgPj+voCxphnjDEzjTEz4+P7HwJpMSGcONXM6ZZ2TjW3cbq1fVBHuHQ1dXgUJ061DOkp6K9sLaK6oYXvLOu+dd4dq5+F2+ZlsOlwFcUnT3f2W3fIiA1h85Eqmttsneu7DITFItw4I5XPDp2gpKb7v8ujVY3n1JGVEkljSzs5xUP/tySlBtMFjXIxxtQA64F5QJSIdPTBpwIXPqf7AqQ5RroUnWzsMgZ98LtcAKZ16UcfiowxvLTpKLMyopmZcWEXDm+ZnUaQv/2fRUbc2V0qwx0jXQKtFqetOHnD9FSMgb9vP7eV3m4zHKtqJP0LdVw6MYGwQCt/+fyoU2pQylv0GugiEi8iUY7vg4HlwD7swX6j42F3AO+6qkjoMnSxqrHLLFH3tNDHJYYT5G9h57GhOWM0p6iGo1WN3DQj7YKfGxUSwHXTUgHOaRmnx9hv9zZc8UKkxYSwYFQsT3986Jx1Xsrqmmhpt51TR3iQPzfOSOX9XaWd8xG6Msbwbk4J1/xhI7lD9ENXKVfoSws9CVgvInnANmCNMWYV8EPg+yJyEPvQxeddVyakRdvXCrG30Ad/UlFXVj8Lk1OjWLO3nKpTzW6p4Xze2VlCgNXCikmJ/Xr+/ReP5uuLR5CVcnbfe8dIkyVO6G7p6rGbpjIhOYL7X83hh2/mde4bWugYOpnezcXX2+el09pu+NuWY2cdP3qigdtf2Mr9r+aQW1zLIx/kO7VWpYayXoctGmPygGndHD+MvT99UMR0GWXR1m6/GBbvpi4XgAeXj+GOP23l1ue28Ld/nUtMaIDbaumqtd3GqrxSlo9P6HYzir5IjAzix5eNP+f4zIwYHlw+hhtmpA60zHNe75V/ncsTaw/w5IaDrMuvYM6ImM6Lnl9soQOMiA9jydh4Xt5yjG8tGUWA1cL7eaU88HoOgX4W/uuaiTQ0t/Pr1flsLzx53rHuSnkLj5gpCvZRFmkxIRRV21vogVYLEUHu23BpzohYnrt9FkdONPCVZzdzssElc6ou2MYDJ6hqaOHaad0OOhqQAKuF71w8mvB+flCcj9XPwr9dOpaX757DvJGx5Byr4Z+7y4gIspIY0f0H953zM6isb+afu0v525ZjfPuVHUxOiWTdg4u5fV4Gt89LJzrEnyfXH3R6vUoNRR61BV1aTAjHqhoJDbQyLCLwrKFs7rBwdBzP3j6Te/6SzV0vbuOtb8zvccz1YHl7ZwlRIf5OGYXiDvNHxXVO2CqtPU27zfT4d7podDyZcaE8/N5eqhtaWDo2nj/eOqNzklJooJW7
F2by6If72V1Se04XklLexmNa6HBmclF5XZPbRrh80aIx8fzimix2Hqthzb6+Lwk7ENsLT3Z7MfCUY5LOlZOTCLB61KntVlJkMKnRPU9esliEO+dnUN3QwjVTk3nm9pnnzDi9fX4G4UFW/vCRttKV9/Oo//Vp0cGcbm0nv6zebSNcunP99BTSY0P4/boDLp9s1NZu46vPbeGGpz6nrPbsUP9gdxlNrTauc0F3y1B129x0Xrt3Lo9/aWq3y/hGBPlz5/wMVu8pY3+5rqOuvJtHBXrHVPOaxtYhFehWPwv3LR3FnuN1rNtX4dLXOnKigdOt7RRVn+bW5zZ3jrLJLarhyQ0HSYsJZvpw37kAaLEIc0bEnrer664FmQRYLby0qXAQK1Nq8HlWoMec+fV7WA8XytzlumkppMUE8/uPXNtKzy+ztzJ/ftUEik+e5rbnt/Lg67lc8+Rn1J1u5b+uznL7tYWhJjo0gEsnJrIy93jnlnhKeSOPCvSu/anxQ6iFDvYVIb+9dBR5xbVsKOj/mjW9yS+rw2oRvjJnOP932wwOVNTzXu5xvr54BOv/bQlLx51/BURfddOMVGpPt7J2kK5zKOUOHhXoQf5+nV0tQ6nLpcP101NJjQ7md+sOuOw18kvrGRkfRqDVjyVjh/HufQtZ9+BifnzZeJcMJ/QWC0bFkRQZxJvdLDGglLfwqECHM90uQ2WUS1f+fhbuXTSCnKIadhXXuuQ18svqGZt4Zm3zCckRnevcqJ75WYTrp6fwyf5KyuvOHSGklDfwuEDvCC93TfvvzTVTUwjyt/DqtmO9P/gC1TW1UlJzmnFJ4b0/WJ3jhump2Az8fYdL15FTym08LtCzUiKJCwsgJmRoTLX/oshgfy6flMS7OcdpbHHuvqMFjgui4xMjnPpzfcWI+DBmpkfz5vaiIb2WvVL95XGBfuf8DDb8YKnbZ2Sezy2zh3OquY3380qd+nPzHbv0aAu9/26ckcqhyoYhu/SxUgPhcYHuZxHCAof2igUz06MZER/Ka9uKnPpz95XVn3dtE9W7KyYnEezvx2tbnXtulBoKPC7QPYGIcPOsNLILT3LAibMTC8rqGZcUoePMByA8yJ+rpySzMvc4tadb3V2OUk6lge4i109Pxd9PnNZKt9kMBWX1jE/U7paB+urcdE63tvP2Dh3CqLyLBrqLxIUFsnxCAm/tKKalzTbgn1dSc5pTzW2M1QuiAzYpNZLJqZG8vOWYXhxVXkUD3YVunJHKycZWPj0w8Jmj+/SCqFN9dU46BypOsfVItbtLUcppNNBdaOGoeKJC/FmZe3zAP6tjDZexCRroznDVlGQigqz8dYvz5wso5S4a6C4UYLVwWVYSa/aWd+6T2V8FZfWkx4YQOsRH+HiK4AA/bpiRyurdpZzoZl/YqlPN1DbqRVPlWTTQXezqKck0trQPeFGofWV12jp3slvnDKe13fBG9tkXR40x3PrcFr7z6k43VaZU/2igu9jszBgSIgIH1O1SUnOaoycaGJ+kF0SdadSwcGZlRPPWjuKzLo7uLa0jv6yezw+eoK5JW+nKc/Qa6CKSJiLrRWSviOwRkfsdxx8SkRIRyXF8Xe76cj2Pn0W4YlIyHxdU9nvc8xNr9mP1s/ClWWlOrk5dNy2VgxWn2FVyZjG1jg/fNpth44ET7ipNqQvWlxZ6G/CgMWYCMBe4T0QmOO573Bgz1fH1D5dV6eGunppMS7uND3aX0dJm46kNh7jmyc+63Rf0iw6U1/PWjmJun5tOSlTwIFTrW66YZN9/tWPBLpvNsCq3lItGxxEZ7M9H+a7dgUopZ+o10I0xpcaYHY7v64F9gO9sWukEU1IjSY8N4YXPjnD57z/l16vzyS2q4Z2dva/695sPCggNsPKtpaMGoVLfExnizyXjh/Fe7nFa221sP3aSkprTXD89hUVj4tlQUIHNpmPVlWe4oD50EckApgFbHIe+LSJ5IvKCiPjORpYXSES4anIy+WX1NLe188Kd
M5mSFsW7OefvV99x7CQf7i3n3kUjiAkdmqtLeoPrpqVS1dDCJ/srWZlznECrheUTElk2Lp4Tp1rYfdw1a9sr5Wx9DnQRCQPeAr5njKkDngJGAlOBUuCxHp53r4hki0h2ZaXrtmYb6r6xZCSPf3kKax5YzLJxCVw9JZk9x+s4WHGq28cbY/j1P/OJCwvgroWZg1ytb1k8Jp6Y0ADeyC7mH7tKuWRCAmGBVhaNjkcE7XZRHqNPgS4i/tjD/GVjzN8BjDHlxph2Y4wNeBaY3d1zjTHPGGNmGmNmxsfHO6tujxMWaOW6aakE+fsBcNXkJETocfTL54eq2HKkmm8vHaVjz10swGrhqslJrN5TRlVDC1dPSQYgNiyQqWlRrNdAVx6iL6NcBHge2GeM+W2X40ldHnYdsNv55XmvYRFBzBsRy3u5x89ZT8QYwxNr95MQEcjNs4e7qULfcv30VADCg6wsGXum4bFs7DByi2uprD938pFSQ01fWugLgNuAZV8YoviIiOwSkTxgKfCAKwv1RtdMTebIiYazhswBbDpUxbajJ/nm4pGdLXrlWpNTI5mSFsVNM9IItJ75O186bhgAGwq0la6Gvl5/lzfGbAS6W4BbhykO0IqJSfznO3tYmXOcyalRncefWHeAYeHaOh9MIsK79y045/jE5AgSIgJ5a0cxV01J1g9YNaTpTFE3igzxZ/HYeN7LO067Y2jcpkNVbD1SzTe0dT4kiAhfXzSSzYerufoPGzv3dVVqKNJAd7Nrp6ZQXtfMxY9t4OH39vCr1fnEhwfylTnaOh8q7lqYyYt3zaa6oZWr/rCR17N1+zo1NGmgu9nlkxL59Q2TyIgL5eUtx8gtqtHW+RC0eEw8/7z/IqYPj+I/3t7dp1m+Sg02GcwdW2bOnGmys7MH7fU8TWNLG/tK65mWFoXFovuGDkVHTjSw7LENfGvJSH5w6Th3l6N8hIhsN8bM7O1x2kIfQkICrMxIj9YwH8Iy40JZMTGRlzYVcqq5zd3lKHUWDXSlLtC9i0ZQ19TGq1t1tyM1tGigK3WBpg2PZk5mDM9vPEJr+8A3AFfKWTTQleqHbyweSWltE+85Yb9YpZxFFwlRqh+WjI1nbEI4P1+5h+c3HiHI348xCWH88rpJ2FfLUGrwaQtdqX4QEf7rmoksGhNPYkQQp1vaeWVrEYcqu189U6nBoC10pfppzohY5oyIBaCoupGLHlnPZwerGDVMN/NW7qEtdKWcIC0mhLSYYD47qHuQKvfRQFfKSeaPiGPz4arOdXmUGmwa6Eo5yfxRsdQ1tbFHt6xTbqKBrpSTzB8ZB8BnB6vcXInyVRroSjlJfHggYxPC+fyQ9qMr99BAV8qJ5o+KZdvRaprb2t1divJBGuhKOdGCkXE0tdrYUVjj7lKUD9JAV8qJZo+IwSKwSbtdlBtooCvlRBFB/kxOjeKzQ3phVA0+DXSlnGzBqFhyimp0vXQ16HoNdBFJE5H1IrJXRPaIyP2O4zEiskZEDjj+jHZ9uUoNfbMyYmi3GfKKtR9dDa6+tNDbgAeNMROAucB9IjIB+BGwzhgzGljnuK2Uz5uaFgVATpEGuhpcvQa6MabUGLPD8X09sA9IAa4BXnQ87EXgWlcVqZQniQoJIDMulJxjPQf6qeY2Xvz8KE2tOrxROc8FrbYoIhnANGALkGCMKXXcVQYkOLUypTzYtLQoPj14AmPMOeuj155u5c4/bWXnsRqC/C18edZwN1WpvE2fL4qKSBjwFvA9Y0xd1/uMMQbodkUiEblXRLJFJLuysnJAxSrlKaYOj6KyvpnjtU1nHa9pbOGrz21hd0ktAVYLu0p03RflPH0KdBHxxx7mLxtj/u44XC4iSY77k4CK7p5rjHnGGDPTGDMzPj7eGTUrNeR19KPvPHay81hdUys3P7OZgvJ6/u+2GUwfHsWuYg105Tx9GeUiwPPAPmPMb7vctRK4w/H9HcC7zi9PKc80LjGCQKvlrH7017YWkV9WzzO3
zWDZuAQmp0axr6yeljbdaFo5R19a6AuA24BlIpLj+Loc+BWwXEQOAJc4biulgACrhayUyM6RLsYYXtl2jBnp0SwZOwyASSmRtLTZ2F9e785SlRfp9aKoMWYj0NOutxc7txylvMfUtCj+urmQ1nYbOUU1HK5s4Dc3juy8f3JqJAC7SmrJSol0V5nKi+hMUaVcZNrwKJrbbOSX1vPK1mOEB1q5YnJS5/3DY0KICLKSp/3oykk00JVykY4Lo58cqOQfu0q5emoyIQFnfikWESalRrKrRCcgKefQQFfKRVKigokLC+SpDYdoarVxczfjzSelRFFQVq/rpyun0EBXykVEhKlpUZxqbmNicgSTUs/tJ5+cGklru6GgTC+MqoHTQFfKhaYNt3e73Dwrrdv7JzkuhvbUj26fs6dU31zQ1H+l1IW5anIyBWX1XDstpdv7U6ODiQ7xP2uC0eHKU6zbV8EnByrZdrSaSSmR/ODScczOjBmsspWH0kBXyoWGx4bw+1um9Xi/iJCVEkmeYwmAN7cX88O38mi3GUYPC+O6aams21fOl/5vE0vGxvPf12SRFhMyWOUrD6OBrpSbTU6N5OmPD/Pk+oP85oMCFo6K45EbJ5McFQzA6ZYJ/GXTUZ5Ye4BHPyzgdzf3/AGhfJsGulJuNiklinab4TcfFHD5pEQe//JUAq1+nfcHB/jx9cUj2V54UhfzUuelF0WVcrPp6VGEBVr5ypzh/O8t088K866yUiI5cqJBt7ZTPdIWulJuNiw8iO3/eUmPQd4hKyUCY2BfaR2zMvQCqTqXttCVGgJ6C3OArGT7EMfd2u2ieqCBrpSHGBYRRFxYILtL6np/sPJJGuhKeZCslAj2HNcWuuqeBrpSHiQrOZIDFad0c2nVLQ10pTxIVkoE7TZd+0V1TwNdKQ8ysePCaJdul8fX7OeJtfvdVZIaQnTYolIeJDU6mMhg/84Lo0dONPC/Hx3A6mfha/MziQzxd3OFyp20ha6UBxERJiafuTD6h48O4mcRWtpsvJtb4ubqlLtpoCvlYbJSIskvredgxSneySnh9nkZTEyO4LVtRe4uTbmZBrpSHmZicgQt7TYefCMXq0X4+qIRfHlWGnuO1+mkIx/Xa6CLyAsiUiEiu7vG32bxAAAQUUlEQVQce0hESkQkx/F1uWvLVEp16LgwmltUwy2zhzMsIohrpqQQYLXwRra20n1ZX1rofwZWdHP8cWPMVMfXP5xbllKqJ5lxoYQE+BFgtfDNJSMBiAzxZ8XERN7JOa5j1H1Yr4FujPkEqB6EWpRSfeBnEb4yezgPXDKGhIigzuNfnpVG7elWPtxb7sbqlDsNpA/92yKS5+iSiXZaRUqpXv3HlRM6W+cd5o2IJTU6mL9uLtS9SH1UfwP9KWAkMBUoBR7r6YEicq+IZItIdmVlZT9fTinVG4tFuHthJluPVLNGW+k+qV+BbowpN8a0G2NswLPA7PM89hljzExjzMz4+Pj+1qmU6oOvzk1n9LAw/vv9vdqX7oP6FegiktTl5nXA7p4eq5QaPP5+Fh66eiJF1ad59pPD7i5HDbK+DFt8BdgEjBWRYhG5G3hERHaJSB6wFHjAxXUqpfpowag4LstK5MkNBzlec9rd5ahB1JdRLrcYY5KMMf7GmFRjzPPGmNuMMZOMMZONMVcbY0oHo1ilVN/89IrxGAM/eDOXA+W6MqOv0JmiSnmh1OgQfnrFeLYeqWb5459w09Of64VSH6CBrpSXun1eBpt/fDE/uXwclfXN3PtSNvu1te7VNNCV8mKxYYHcu2gkb39rAWEBVh77sMDdJSkX0kBXygdEhwZwz0Uj+GBPOblFNe4uR7mIBrpSPuLuizKJCQ3gUW2ley0NdKV8RFiglW8uHsmnB06w+XCVu8tRLqCBrpQPuW1eOgkRgfzqn/nUNbW6uxzlZBroSvmQIH8/frhiHDlFNSx6ZD3PfHJIlwjwIhroSvmY66ensuo7C5mSGsUv/5HP8sc/
5mRDi7vLUk6gga6UD8pKieTFu2bzp6/Noqj6NH/beszdJSkn0EBXyoctHTuMhaPi+Mumo7S229xdjhogDXSlfNxdCzMor2vmH7t0SSZPp4GulI9bMmYYmXGhvLDxiO505OE00JXycRaL8LUFGeQW17LjmH0WaUV9Ey9+fpT1BRXUntbhjZ7C6u4ClFLud8P0VH7zQQFPbTjI2MRwXth4lNOO4YwiMDYhnJ9dNYH5I+PcXKk6H22hK6UIDbRyy+zhrN1XwZPrD3HJhARWf+8i/nbPHB64ZAzNbTbufGEbq/KOu7tUdR7aQldKAfD1RSMwxnDttBQmJkd2Hp8/Ko7b56Vzz4vZfOeVnZyob+bOBZlurFT1RFvoSinAvtTuT6+YcFaYd4gKCeCv98zhkvEJPPTeXv66udANFareaKArpfokyN+Pp26dzqIx8fz3qr0crDjl7pLUF2igK6X6zOpn4dEbJxMS4McDr+XQ0qaTkYYSDXSl1AUZFhHE/7t+ErtKavn9ugPuLkd1oYGulLpgK7KSuHFGKn/ccJCVucd1QtIQ0Wugi8gLIlIhIru7HIsRkTUicsDxZ7Rry1RKDTU/v2oCYxLC+e4rO7n2yc/4ZH+lBrub9aWF/mdgxReO/QhYZ4wZDaxz3FZK+ZDwIH9WfWchj9wwmROnWrj9ha1c+sQnPLF2P/vL691dnk+SvnyiikgGsMoYk+W4XQAsMcaUikgSsMEYM7a3nzNz5kyTnZ09sIqVUkNOc1s7b2QXszLnONsKqzEG0mNDuGh0HBeNjmfBqDjCAnXaS3+JyHZjzMxeH9fPQK8xxkQ5vhfgZMftbp57L3AvwPDhw2cUFur4VaW8WUV9Ex/sKefjggo2HaqioaWdsEArN81M5c75GaTHhrq7RI8zaIHuuH3SGNNrP7q20JXyLS1tNrILq3ltWxHv55XSbgx3LcjkP6+c4O7SPEpfA72/o1zKHV0tOP6s6OfPUUp5sQCrhfkj4/jdzdP47EfLuH5aKs9vPMKWw1XuLs0r9TfQVwJ3OL6/A3jXOeUopbxVQkQQv7g2i5SoYB5+by/tNh0R42x9Gbb4CrAJGCsixSJyN/ArYLmIHAAucdxWSqnzCg7w48eXj2NvaR2vZxe5uxyv0+tlZ2PMLT3cdbGTa1FK+YArJiXxl4xCHv2ggMsnJREZ7O/ukryGzhRVSg0qEeFnV02gurGFRz8o0MlITqSBrpQadFkpkdw+N52XNhfyzb/uoLbxzDZ3LW02ahpb3Fid59KR/kopt/j5VRNJjQ7h16vzufz3n/K1BRlsO1rNxgMnsIjw+Y+XER6k3TEXQlvoSim3sFiEf100gre+OR8/i/CL9/eRV1zLwtFx1De3sb6g0t0lehxtoSul3GpKWhQffG8RZXVNZMSGYDMw55dr+XBPGVdPSXZ3eR5FW+hKKbcLDvAjMy4UEcHPIlwyPoENBZU0t7W7uzSPooGulBpyLp2YyKnmNj4/pDNKL4QGulJqyJk3MpbQAD8+3FPm7lI8iga6UmrICfL3Y8m4YazZW965REB5XRPPfXqYplbthumJBrpSakj6lwkJnDjVws5jJ6mob+KWZzbzi/f38diHBe4ubcjSQFdKDUlLxw3D3094dVsRtz67hdLaJpaMjee5jUf4/NCJPv2Mj/dXnjVpydtpoCulhqSIIH/mj4zjze3FHKtu5Pk7Z/LHW6eTERvKv72eS+1pe1A3tbaz9Ug1be22s56/Mvc4d7ywle+/nuOO8t1CA10pNWTdNDOVsEArz94+k/kj4wgJsPL4l6dSXt/Mg6/n8h/v7GL2/6zlS/+3ie+/ntvZ315Sc5qfvr2L8CAr6/IrWLu33M3vZHDoxCKl1JB15eRkVkxMxOp3pu05NS2K7y4bzeNr9xPkb2HFxERiQgN54bMjBPv78T/XZfH913Kw2Qwrv7OQb7y0nYdX7WHh6DiC/P3c+G5cTwNdKTWkdQ3zDt9eNooZ6dFMSYvsXO8l
LNCP3390kLySWvaV1vHoTVMYGR/Gw9dM5CvPbuHpjw/xvUvGDHb5g0q7XJRSHsfPIiwcHXfW4l0PLB/DPQsz2VdaxxWTkrhhegoA80fGcdWUZP644RBHTzS4q+RB0adNop1FN4lWSrmSMYaP91cyOzOGkIAzHRBltU0s/+3HANy3bBR3zs/wqO4XV28SrZRSQ46IsGTssLPCHCAxMoh3vr2AOSNi+NU/87n4sY/ZeKBvQx89iQa6UsonjIwP47k7ZvG3e+YQHODHXS9u6/N4dk+hga6U8inzR8Xx+tfnkR4Twj0vZrPj2El3l+Q0GuhKKZ8TExrAy/fMIT48kDtf2MrK3OPkFtVQWnsam81z9zgd0EVRETkK1APtQFtvnfZ6UVQpNZQUn2zkS09v4nhtU+exi8cN4/k7Z7mxqnP19aKoM8ahLzXGeFdHlFLKJ6RGh7Dm+4s5UHGKyvpmVu8u460dxRysOMWoYWHuLu+CaZeLUsqnhQZamZoWxfIJCfzosnH4+wkvbyl0d1n9MtBAN8CHIrJdRO51RkFKKeUu8eGBrMhK4q3txZxu8bx11wca6AuNMdOBy4D7RGTRFx8gIveKSLaIZFdW6i7eSqmh7atzhlPX1MZ7ucfdXcoFG1CgG2NKHH9WAG8Ds7t5zDPGmJnGmJnx8fEDeTmllHK52ZkxjEkI46XNvXe71DS2cKyqcRCq6pt+XxQVkVDAYoypd3z/L8B/Oa0ypZRyAxHhq3PT+dm7e8gtqmFKWtRZ92cfrebVbUXsOHaSw5UNWC3Cqu8uZFxihJsqPmMgLfQEYKOI5AJbgfeNMaudU5ZSSrnPddNSCAnw49EPC8gpqsFmM1TUNfHAaznc+PQm1u4rZ0RcKD+4dCyhgVZ+sWofXxwC3vqFDTcGQ79b6MaYw8AUJ9ailFJDQniQP99cPJLfrt3PpwdOEBMaQEubjZY2G99eOopvLR3ZuV5MSIAfD7+3l4/yK7h4fAIAL20u5KGVe5iSGsmlExO5dGIiGXGhLq9bV1tUSqkeVDe08OmBSj4uqMRmDPdfMobMLwRza7uNS5/4BAys/t4i3ss9zoNv5DIzPZqmtnZ2l9QB8PRXp7MiK6lfdfR1YpEGulJKDdBH+eXc9edsLp2YwJq95cwbGcvzd8wiyN+PoupGPtxbzg3TU4gKCejXzx/MmaJKKeXTlo4dxkWj4/hgTzkz0qN59vaZneutp8WEcPfCzEGpQwNdKaUGSET45XWTeGlzIfctHXXOeuyDRQNdKaWcIC0mhJ9cPt6tNehaLkop5SU00JVSyktooCullJfQQFdKKS+hga6UUl5CA10ppbyEBrpSSnkJDXSllPISg7qWi4hUAv3drC8O8MXNqH3xffviewbffN+++J7hwt93ujGm1x2CBjXQB0JEsvuyOI238cX37YvvGXzzffviewbXvW/tclFKKS+hga6UUl7CkwL9GXcX4Ca++L598T2Db75vX3zP4KL37TF96Eoppc7Pk1roSimlzsMjAl1EVohIgYgcFJEfubseVxCRNBFZLyJ7RWSPiNzvOB4jImtE5IDjz2h31+psIuInIjtFZJXjdqaIbHGc79dEpH/7dg1hIhIlIm+KSL6I7BORed5+rkXkAce/7d0i8oqIBHnjuRaRF0SkQkR2dznW7bkVu9873n+eiEwfyGsP+UAXET/gSeAyYAJwi4hMcG9VLtEGPGiMmQDMBe5zvM8fAeuMMaOBdY7b3uZ+YF+X278GHjfGjAJOAne7pSrX+h2w2hgzDpiC/f177bkWkRTgu8BMY0wW4AfcjHee6z8DK75wrKdzexkw2vF1L/DUQF54yAc6MBs4aIw5bIxpAV4FrnFzTU5njCk1xuxwfF+P/T94Cvb3+qLjYS8C17qnQtcQkVTgCuA5x20BlgFvOh7ije85ElgEPA9gjGkxxtTg5eca+w5pwSJiBUKAUrzwXBtjPgGqv3C4p3N7DfAXY7cZiBKRpP6+ticEegpQ
1OV2seOY1xKRDGAasAVIMMaUOu4qAxLcVJarPAH8O2Bz3I4FaowxbY7b3ni+M4FK4E+OrqbnRCQULz7XxpgS4FHgGPYgrwW24/3nukNP59ap+eYJge5TRCQMeAv4njGmrut9xj4kyWuGJYnIlUCFMWa7u2sZZFZgOvCUMWYa0MAXule88FxHY2+NZgLJQCjndkv4BFeeW08I9BIgrcvtVMcxryMi/tjD/GVjzN8dh8s7fgVz/FnhrvpcYAFwtYgcxd6Vtgx733KU49dy8M7zXQwUG2O2OG6/iT3gvflcXwIcMcZUGmNagb9jP//efq479HRunZpvnhDo24DRjqvhAdgvpKx0c01O5+g7fh7YZ4z5bZe7VgJ3OL6/A3h3sGtzFWPMj40xqcaYDOzn9SNjzK3AeuBGx8O86j0DGGPKgCIRGes4dDGwFy8+19i7WuaKSIjj33rHe/bqc91FT+d2JXC7Y7TLXKC2S9fMhTPGDPkv4HJgP3AI+Km763HRe1yI/dewPCDH8XU59j7ldcABYC0Q4+5aXfT+lwCrHN+PALYCB4E3gEB31+eC9zsVyHac73eAaG8/18DDQD6wG3gJCPTGcw28gv06QSv238bu7uncAoJ9FN8hYBf2UUD9fm2dKaqUUl7CE7pclFJK9YEGulJKeQkNdKWU8hIa6Eop5SU00JVSyktooCullJfQQFdKKS+hga6UUl7i/wMh8Gb6r/s6eQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(grads)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice how the gradients decay much more slowly than they did for the simple RNN. Now let's see what happens when we don't initialize the forget gate bias to 1 and set it to 0 instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm.b_f.data = torch.zeros_like(lstm.b_f.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "h_0, c_0 = (torch.zeros(hidden_size, requires_grad=True), \n",
    "            torch.zeros(hidden_size, requires_grad=True))\n",
    "grads = []\n",
    "h_t, c_t = h_0, c_0\n",
    "for t in range(100):\n",
    "    h_t, c_t = lstm_step(\n",
    "        test_embeddings[:, t, :], h_t, c_t,\n",
    "        lstm.W_ii, lstm.W_hi, lstm.b_i,\n",
    "        lstm.W_if, lstm.W_hf, lstm.b_f,\n",
    "        lstm.W_ig, lstm.W_hg, lstm.b_g,\n",
    "        lstm.W_io, lstm.W_ho, lstm.b_o,\n",
    "        use_forget_gate=True,\n",
    "    )\n",
    "    loss = h_t.abs().sum()\n",
    "    loss.backward(retain_graph=True)\n",
    "    grads.append(torch.norm(h_0.grad).item())\n",
    "    h_0.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1a7c099208>]"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAFw1JREFUeJzt3X1sHPWdx/H317vr54eE2HGcB5JAHiDlIYBLQbSFg3IH9FQerlfgTm3UQ011Kjp66qniHqRrpXtopba0p2vRpUBJTy19AHpwFLWlKYXSFhqHhiQkkAcOShInNiSxHcfx4/f+2DGYxGtvvLsez8znJa12d3ac+Y4GPv75O7/ZMXdHRESiryzsAkREpDgU6CIiMaFAFxGJCQW6iEhMKNBFRGJCgS4iEhMKdBGRmFCgi4jEhAJdRCQm0tO5scbGRl+yZMl0blJEJPI2bdr0hrs3TbbetAb6kiVLaGtrm85NiohEnpm9ls96armIiMSEAl1EJCYU6CIiMaFAFxGJCQW6iEhMKNBFRGJCgS4iEhORCPQNOw7yjV/uDrsMEZEZLRKB/vTOTv7rqVfCLkNEZEaLRKBXlafpGxgOuwwRkRktEoFeU55iYHiEoeGRsEsREZmxIhHoVeUpAI4NapQuIpLLpIFuZpVm9jsze8HMXjSzzwfLl5rZc2a228y+b2blpSqyujz7HWLH+hXoIiK55DNC7weudPfzgdXANWZ2CfBF4C53XwYcBm4rVZE1FcEIfWCoVJsQEYm8SQPds44GbzPBw4ErgQeD5euBG0pSIVCVGQ10jdBFRHLJq4duZikz2wx0AE8Ae4Aj7j46ZN4LLChNiWNaLgp0EZGc8gp0dx9299XAQuBi4Kx8N2Bma82szczaOjs7p1TkWydF1XIREcnplGa5uPsR4EngUmCWmY3e8WghsC/Hz6xz91Z3b21qmvQOSuMa7aFrLrqISG75zHJpMrNZwesq4GpgB9lg/3Cw2hrgkVIVWZ3J/t7oVaCLiOSUzz1FW4D1ZpYi+wvgB+7+mJltB75nZv8C/B64t1RFjrZc+tRyERHJadJAd/ctwAXjLH+FbD+95N6etqgRuohILpG4UrQyrUAXEZlMJAK9rMyoyqQ0y0VEZAKRCHSA6vKURugiIhOITqBXpDRtUURkAtEJ9EyaXrVcRERyikygV6nlIiIyocgEeo1aLiIiE4pMoFdl0rpSVERkApEJ9OrylK4UFRGZQGQCvaZCPXQRkYlEJtCrMmn10EVEJhCZQK8uT9E7MIS7h12KiMiMFJlArypPMeLQPzQSdikiIjNSZAK9plw3uRARmUhkAn30vqK6WlREZHyRCfQqjdBFRCYUmUDXTS5ERCYWmUCvyqjlIiIykcgEerVaLiIiE4pcoKvlIiIyvugEekW25aIRuojI+KIT6JnsCF09dBGR8UUm0KvUchERmdCkgW5mi8zsSTPbbmYvmtkdwfLPmdk+M9scPK4rZaEV6TJSZaaWi4hIDuk81hkCPuPuz5tZHbDJzJ4IPrvL3b9UuvLeZmZUZ1JquYiI5DBpoLt7O9AevO4xsx3AglIXNp6qct2GTkQkl1PqoZvZEuAC4Llg0e1mtsXM7jOz2UWu7STVulG0iEhOeQe6mdUCDwGfdvdu4G7gTGA12RH8l3P83FozazOzts7OzoKKrS5Pc0wtFxGRceUV6GaWIRvm33H3hwHc/aC7D7v7CPBN4OLxftbd17l7q7u3NjU1FVSsRugiIrnlM8vFgHuBHe7+lTHLW8asdiOwrfjlvVOVAl1EJKd8ZrlcBnwU2Gpmm4Nl/wDcamarAQdeBT5ZkgrHqClP09HdX+rNiIhEUj6zXJ4BbJyPHi9+ORMbva+oiIicLDJXioKmLYqITCRSgV5TkVYPXUQkh0gFelUmRd/gMCMjHnYpIiIzTqQC/a2bXAxqlC4icqJIBrraLiIiJ4tYoGcn
5ehqURGRk0Us0DVCFxHJJVKBrptciIjkFqlAr9F9RUVEcopUoFfpvqIiIjlFKtDfmraoEbqIyEkiFuijs1wU6CIiJ4pWoFeMnhRVy0VE5ETRCvSMZrmIiOQSqUBPp8ooT5Up0EVExhGpQIds20UtFxGRk0Uv0DO6DZ2IyHgiF+i6yYWIyPgiF+jV5Wm1XERExhHBQE/RqxG6iMhJIhnoarmIiJwsgoGulouIyHgiGOia5SIiMp5JA93MFpnZk2a23cxeNLM7guWnmdkTZrYreJ5d+nIV6CIiueQzQh8CPuPuq4BLgE+Z2SrgTmCDuy8HNgTvS64qaLm4+3RsTkQkMiYNdHdvd/fng9c9wA5gAXA9sD5YbT1wQ6mKHGvh7CoGh529h/umY3MiIpFxSj10M1sCXAA8BzS7e3vw0QGgOcfPrDWzNjNr6+zsLKDUrHMXNACwbV9Xwf+WiEic5B3oZlYLPAR82t27x37m2f7HuD0Qd1/n7q3u3trU1FRQsQAr59WRLjO2KtBFRN4hr0A3swzZMP+Ouz8cLD5oZi3B5y1AR2lKfKfKTIoVzXUKdBGRE+Qzy8WAe4Ed7v6VMR89CqwJXq8BHil+eeM7d0ED2/Z16cSoiMgY+YzQLwM+ClxpZpuDx3XAF4CrzWwX8IHg/bQ4Z2EDh48Nsu+IToyKiIxKT7aCuz8DWI6PrypuOfkZe2J04ezqMEoQEZlxInelKMBZOjEqInKSSAZ6ZSbF8uY6tu7rnnxlEZGEiGSgA5y7oF4nRkVExohwoDdwqHeA/V3Hwy5FRGRGiGygnxOcGN26V310ERGIcKCf3VJPqszYuu9I2KWIiMwIkQ30t68Y1YlRERGIcKCDToyKiIwV6UA/u6WeQ70DvHF0IOxSRERCF+lAb2moAuBgt2a6iIhEOtCb6ysABbqICEQ80Oc1VAJwsLs/5EpERMIX6UBvrK3ADA5ohC4iEu1Az6TKaKytoEOBLiIS7UCHbB9dI3QRkRgE+rz6SvXQRUSIQaA311dqlouICDEJ9EO9A/QPDYddiohIqCIf6PPqs1MXO9R2EZGEi3ygz9XFRSIiQAwCXRcXiYhkRT7Qm+uyga6piyKSdJMGupndZ2YdZrZtzLLPmdk+M9scPK4rbZm5zarOUJ4u08VFIpJ4+YzQ7weuGWf5Xe6+Ong8Xtyy8mdmurhIRIQ8At3dnwYOTUMtUzZPc9FFRArqod9uZluClszsolU0Bc26WlREZMqBfjdwJrAaaAe+nGtFM1trZm1m1tbZ2TnFzU2sub6SA13HdSs6EUm0KQW6ux9092F3HwG+CVw8wbrr3L3V3VubmpqmWueE5tVX0jc4TE//UEn+fRGRKJhSoJtZy5i3NwLbcq07Hd66uKhLfXQRSa70ZCuY2QPAFUCjme0F/hm4wsxWAw68CnyyhDVOavTy/4Pd/SxvrguzFBGR0Ewa6O5+6ziL7y1BLVPWXK+Li0REIn+lKLwd6Jq6KCJJFotArypPUV+ZVqCLSKLFItAh+yVdCnQRSbLYBHpzfSUHdHGRiCRYrAJdX9AlIkkWm0CfV19JR08/wyO6WlREkik2gd5cX8HwiPPmUbVdRCSZYhPoi+fUALC782jIlYiIhCM2gX52Sz0A2/d3h1yJiEg4YhPoTXUVzK2rYHu7Al1Ekik2gQ6wan49O9p7wi5DRCQU8Qr0lnp2d/QwMDQSdikiItMuXoE+v57BYWdXh0bpIpI88Qp0nRgVkQSLVaAvnlNDVSalE6MikkixCvRUmXFWS51G6CKSSLEKdMi2Xba3d+uG0SKSOPEL9Pn19BwfYu/hvrBLERGZVvEL9ODE6A710UUkYWIX6GfNq6fM0IlREUmc2AV6VXmKpY01OjEqIokTu0CH7Bd1aYQuIkkTy0BfNb+evYf76OobDLsUEZFpM2mgm9l9ZtZhZtvGLDvNzJ4ws13B8+zSlnlqztaJURFJoHxG6PcD15yw7E5gg7svBzYE72eMs+bV
AbDroL7TRUSSY9JAd/engUMnLL4eWB+8Xg/cUOS6CjKvvpK6ijQ7D+ruRSKSHFPtoTe7e3vw+gDQXKR6isLMWDGvjpc1QheRBCn4pKhnr7HPeZ29ma01szYza+vs7Cx0c3lb0VzHzoM9+goAEUmMqQb6QTNrAQieO3Kt6O7r3L3V3VubmpqmuLlTt7K5liPHBuk82j9t2xQRCdNUA/1RYE3weg3wSHHKKZ4VzdkTozsPqI8uIsmQz7TFB4DfAivNbK+Z3QZ8AbjazHYBHwjezygrgpku6qOLSFKkJ1vB3W/N8dFVRa6lqBprK5hTU66piyKSGLG8UnTU8uZajdBFJDFiHegrm+vYdfCoZrqISCLEOtBXzKvjaP8Q+7uOh12KiEjJxTvQ35rporaLiMRfvAN9rma6iEhyxDrQG6ozzKuvZKcCXUQSINaBDtmZLgp0EUmC2Af66EyX4RHNdBGReIt9oK+YV0f/0AivHzoWdikiIiUV+0A/e1727kVb9nWFXImISGnFP9Bb6qivTPPrXW+EXYqISEnFPtDTqTIuW9bIr3Z16opREYm12Ac6wPuWN7G/6zh7OnvDLkVEpGQSEuiNAPxq1/TdMUlEZLolItAXnVbN0sYafqU+uojEWCICHeC9yxp59pU3GRgaCbsUEZGSSEygv295I8cGhnn+D4fDLkVEpCQSE+iXnjmHVJmpjy4isZWYQK+rzHDh6bPURxeR2EpMoEN2+uLWfV0c6h0IuxQRkaJLWKA34g7P7NYoXUTiJ1GBft7CWTRUZXh6p/roIhI/iQr0VJnxvuWNPLVTXwMgIvFTUKCb2atmttXMNptZW7GKKqUrVs6ls6ef7e3dYZciIlJUxRih/5G7r3b31iL8WyX3/uBrAJ5S20VEYiZRLReAufWVrGqp56mXFegiEi+FBroDPzOzTWa2drwVzGytmbWZWVtn58wI0ctXNrHptcP0HB8MuxQRkaIpNNDf6+4XAtcCnzKz95+4gruvc/dWd29tamoqcHPFcfmKJoZGnN/seTPsUkREiqagQHf3fcFzB/Aj4OJiFFVqFy2eTW1FWn10EYmVKQe6mdWYWd3oa+CPgW3FKqyUMqkyLls2h6de1vRFEYmPQkbozcAzZvYC8Dvgx+7+k+KUVXqXr5jLviN97Ok8GnYpIiJFkZ7qD7r7K8D5RaxlWl2+MtvPf2J7B8vm1oVcjYhI4RI3bXHUgllVXLR4Ng9uel1tFxGJhcQGOsDNrYvY09mrm16ISCwkOtCvO6+F6vIUP9i4N+xSREQKluhAr61I88FzW3hsy356+4fCLkdEpCCJDnSAm9+9iN6BYX68tT3sUkRECpL4QL9o8WzOaKrhh22vh12KiEhBEh/oZsZHWhex8dXDmpMuIpGW+EAHuOnCBaTLjM89+iLHB4fDLkdEZEoU6MDcukr+9cZzeGb3G3zi220KdRGJJAV64OZ3n84X/+w8ntn9Bn91/0b6BhTqIhItCvQxPtK6iC//+fk8+8qbfPahLWGXIyJyShToJ7jpwoXccdUK/veF/fzipYNhlyMikjcF+jj++oozWdFcyz/9aBtHdcGRiESEAn0c5eky/v2m82jvPs6Xfvpy2OWIiORFgZ7DRYtns+bSJaz/7atsek1f3iUiM58CfQJ/9ycrmd9Qxe3ffZ79R/rCLkdEZEIK9AnUVqT55sdaOXp8iI/e+xyHegfCLklEJCcF+iRWza/nnjWt7D3cx8fv36hvZRSRGUuBnof3nDGH//yLC9m2r4uPf2sjhzVSF5EZSIGep6tXNXPXzavZ/PoRbvzGr/VFXiIy4yjQT8GHzp/Pdz/xHnqOD3Hj13/Noy/sp/v4YNhliYgAYNN5g+TW1lZva2ubtu2VyuuHjnHb+o3sPHiUMoNzFzRw2bJGrju3hXfNr8fMwi5RRGLEzDa5e+uk6xUS6GZ2DfA1IAXc4+5fmGj9uAQ6QP/QMJteO8yzrxzi2T1vsukPhxkecRbPqeba
c1q4elUzFyyaRVmZwl1EClPyQDezFLATuBrYC2wEbnX37bl+Jk6BfqJDvQP87MUD/HhrO7/Z8ybDI05jbTmXLWukpaGKproKmusrWNlcx9LGGtIpdbtEJD/5Bnq6gG1cDOx291eCDX4PuB7IGehxdlpNObdcfDq3XHw6XccG+eXODn6+o4O2Vw/T2XOAgeGRt9YtT5dxRmMNs6vLqa1MU1+ZYVZ1hllV2ed0qoyUGWVlRm1FmoaqDA1VGTKpt0f7lZkUdZVpaivS+uUgIkBhgb4AGHsjzr3AeworJx4aqjNcv3oB169eAIC70903xN4jx3j5QA8vHehhT8dRuo8P8vqhY3T3DdLVN0jvFL+DPV1mlJlhxjufAYLfAWVmpMuM1Jh1jewt+MwI3p+wPNcGc3yQT3MpzPMLan5JmP7tpnN595LTSrqNQgI9L2a2FlgLcPrpp5d6czOSmdFQnaGhuoF3zW/IuV7/0DBdfYMMjzgjDsPDTk//IF3HsoE/HLTH3KFvcJie40P0HB9kcHiEEYeREcfJPo84ZN9ljYw4w+4Mj2Qf7uDBv+UE793HLBtfrhZdXo27aTj/7jjj/Sry6di4yASqMqmSb6OQQN8HLBrzfmGw7B3cfR2wDrI99AK2F3sV6RRz60p/0EUkngppvm4ElpvZUjMrB24BHi1OWSIicqqmPEJ39yEzux34Kdlpi/e5+4tFq0xERE5JQT10d38ceLxItYiISAE0301EJCYU6CIiMaFAFxGJCQW6iEhMKNBFRGJiWr8+18w6gdem+OONwBtFLCcqkrjfSdxnSOZ+J3Gf4dT3e7G7N0220rQGeiHMrC2fbxuLmyTudxL3GZK530ncZyjdfqvlIiISEwp0EZGYiFKgrwu7gJAkcb+TuM+QzP1O4j5DifY7Mj10ERGZWJRG6CIiMoFIBLqZXWNmL5vZbjO7M+x6SsHMFpnZk2a23cxeNLM7guWnmdkTZrYreJ4ddq3FZmYpM/u9mT0WvF9qZs8Fx/v7wdczx4qZzTKzB83sJTPbYWaXxv1Ym9nfBv9tbzOzB8ysMo7H2szuM7MOM9s2Ztm4x9ay/iPY/y1mdmEh257xgR7cjPrrwLXAKuBWM1sVblUlMQR8xt1XAZcAnwr2805gg7svBzYE7+PmDmDHmPdfBO5y92XAYeC2UKoqra8BP3H3s4Dzye5/bI+1mS0A/gZodfdzyH7l9i3E81jfD1xzwrJcx/ZaYHnwWAvcXciGZ3ygM+Zm1O4+AIzejDpW3L3d3Z8PXveQ/R98Adl9XR+sth64IZwKS8PMFgIfBO4J3htwJfBgsEoc97kBeD9wL4C7D7j7EWJ+rMl+XXeVmaWBaqCdGB5rd38aOHTC4lzH9nrg2571LDDLzFqmuu0oBPp4N6NeEFIt08LMlgAXAM8Bze7eHnx0AGgOqaxS+SrwWWAkeD8HOOLuQ8H7OB7vpUAn8K2g1XSPmdUQ42Pt7vuALwF/IBvkXcAm4n+sR+U6tkXNtygEeqKYWS3wEPBpd+8e+5lnpyTFZlqSmf0p0OHum8KuZZqlgQuBu939AqCXE9orMTzWs8mORpcC84EaTm5LJEIpj20UAj2vm1HHgZllyIb5d9z94WDxwdE/wYLnjrDqK4HLgA+Z2atkW2lXku0tzwr+LId4Hu+9wF53fy54/yDZgI/zsf4A8H/u3unug8DDZI9/3I/1qFzHtqj5FoVAT8TNqIPe8b3ADnf/ypiPHgXWBK/XAI9Md22l4u5/7+4L3X0J2eP6C3f/S+BJ4MPBarHaZwB3PwC8bmYrg0VXAduJ8bEm22q5xMyqg//WR/c51sd6jFzH9lHgY8Fsl0uArjGtmVPn7jP+AVwH7AT2AP8Ydj0l2sf3kv0zbAuwOXhcR7anvAHYBfwcOC3sWku0/1cAjwWvzwB+B+wGfghUhF1fCfZ3NdAWHO//AWbH/VgDnwdeArYB/w1UxPFYAw+QPU8wSPavsdtyHVvAyM7i2wNs
JTsLaMrb1pWiIiIxEYWWi4iI5EGBLiISEwp0EZGYUKCLiMSEAl1EJCYU6CIiMaFAFxGJCQW6iEhM/D/jMPzFwb/8OgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(grads)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The gradient decays much more quickly now. This is why initializing the forget gate bias to 1 is a good idea, at least in the initial stages of training.\n",
    "\n",
    "Now, let's see what happens when we initialize the forget gate bias to -1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm.b_f.data = -torch.ones_like(lstm.b_f.data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recompute the gradient norms w.r.t. h_0 with the forget gate bias set to -1\n",
    "h_0, c_0 = (torch.zeros(hidden_size, requires_grad=True), \n",
    "            torch.zeros(hidden_size, requires_grad=True))\n",
    "grads = []\n",
    "h_t, c_t = h_0, c_0\n",
    "for t in range(100):\n",
    "    h_t, c_t = lstm_step(\n",
    "        test_embeddings[:, t, :], h_t, c_t,\n",
    "        lstm.W_ii, lstm.W_hi, lstm.b_i,\n",
    "        lstm.W_if, lstm.W_hf, lstm.b_f,\n",
    "        lstm.W_ig, lstm.W_hg, lstm.b_g,\n",
    "        lstm.W_io, lstm.W_ho, lstm.b_o,\n",
    "        use_forget_gate=True,\n",
    "    )\n",
    "    loss = h_t.abs().sum()\n",
    "    loss.backward(retain_graph=True)\n",
    "    grads.append(torch.norm(h_0.grad).item())\n",
    "    h_0.grad.zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x1a7c2ea198>]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD8CAYAAABn919SAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAE/VJREFUeJzt3X2MZXV9x/H3d+48sewzOyyUp0VdsMQG0BHxiVJRi7Qp2JgqaSxpSNe0mmKDadD+UU2aVutT22htVkGxsahFVGKIliJCMYjOIiKw4gKCsFl2B2FhcWVnZ+bbP+4ZmN29d+buzNy5c859v5LJvffcM3u+Jwc+87vf+zvnRGYiSSq/nk4XIElaGAa6JFWEgS5JFWGgS1JFGOiSVBEGuiRVhIEuSRVhoEtSRRjoklQRvYu5sXXr1uWGDRsWc5OSVHpbtmx5IjOHZltvUQN9w4YNjIyMLOYmJan0IuKRVtaz5SJJFWGgS1JFGOiSVBEGuiRVhIEuSRVhoEtSRRjoklQRpQj0m7bu5N+/90Cny5CkJa0UgX7Lz0fZfOtDnS5Dkpa0UgR6f62HsfHJTpchSUtaOQK910CXpNmUJtDHJ5PJyex0KZK0ZM0a6BExGBE/jIifRMS9EfGhYvnJEXFHRDwQEV+JiP52FdnfWy9zbMJRuiQ108oIfR/whsw8HTgDOD8izgY+AnwyM18CPAVc2q4i+2v1MvfZdpGkpmYN9Kx7tnjZV/wk8Abg2mL51cBFbakQGJgaoRvoktRUSz30iKhFxF3ALuBG4EFgd2aOF6s8BhzXnhJtuUhSK1oK9MycyMwzgOOBs4CXtrqBiNgUESMRMTI6OjqnIvsdoUvSrA5rlktm7gZuBl4NrI6IqTseHQ9sb/I7mzNzODOHh4ZmvYNSQ301A12SZtPKLJehiFhdPD8CeBOwlXqwv61Y7RLgm+0qst9Al6RZtXJP0WOBqyOiRv0PwFcz81sRcR/w5Yj4B+DHwJXtKvKFHvpEuzYhSaU3a6Bn5t3AmQ2WP0S9n952U4HutEVJaq4UZ4o6bVGSZleKQO+v1QADXZJmUo5AL0bo+ye8loskNVOqQPdLUUlqrlyBbstFkpoqR6A7D12SZlWOQHfaoiTNqhSBPuDFuSRpVqUIdFsukjS7UgR6T0/Q2xMGuiTNoBSBDt4oWpJmU65At4cuSU2VJ9BrjtAlaSblCXRbLpI0o1IF+j5bLpLUVHkC3ZaLJM2oNIE+YMtFkmZUmkC3hy5JMytNoPfVnLYoSTMpTaA7QpekmZUn0P1SVJJmVJ5A90xRSZpRuQLdEbokNTVroEfECRFxc0TcFxH3RsRlxfIPRsT2iLir+LmgnYUO9PZ4gwtJmkFvC+uMA5dn5p0RsQLYEhE3Fu99MjM/1r7yXlDvoXuTaElqZtZAz8wdwI7i+Z6I2Aoc1+7CDmYPXZJmdlg99IjYAJwJ3FEsek9E3B0RV0XEmgWu7QD9vT3sn8h2bkKSSq3lQI+I5cDXgPdm5jPAZ4AXA2dQH8F/vMnvbYqIkYgYGR0dnXOh/bUaE5PJxKShLkmNtBToEdFHPcy/lJnXAWTmzsycyMxJ4LPAWY1+NzM3Z+ZwZg4PDQ3NudD+Xu8rKkkzaWWWSwBXAlsz8xPTlh87bbW3AvcsfHkvMNAlaWatzHJ5LfBO4KcRcVex7APAxRFxBpDAw8C72lJhYSrQ901MAH3t3JQklVIrs1xuA6LBWzcsfDnNDdQcoUvSTEp1pigY6JLUTPkC3bnoktRQeQLdloskzag8gW7LRZJmZKBLUkWULtD32UOXpIbKE+j20CVpRqUJ9AFbLpI0o9IEep8jdEmaUWkC3XnokjSz8gW6I3RJashAl6SKKE+g12y5SNJMShfo+xyhS1JDpQn0np6grxa2XCSpidIEOtRH6Qa6JDVWrkDv7WFsYqLTZUjSklS6QN8/np0uQ5KWpNIF
urNcJKmxcgW6PXRJaqpcgd5bc9qiJDVRskC35SJJzZQq0AdqPYyNO8tFkhqZNdAj4oSIuDki7ouIeyPismL52oi4MSK2FY9r2l1sf689dElqppUR+jhweWaeBpwNvDsiTgOuAG7KzI3ATcXrtrLlIknNzRrombkjM+8snu8BtgLHARcCVxerXQ1c1K4ipzjLRZKaO6weekRsAM4E7gDWZ+aO4q3HgfVNfmdTRIxExMjo6Og8SrXlIkkzaTnQI2I58DXgvZn5zPT3MjOBhqdwZubmzBzOzOGhoaF5FWugS1JzLQV6RPRRD/MvZeZ1xeKdEXFs8f6xwK72lPgCe+iS1Fwrs1wCuBLYmpmfmPbW9cAlxfNLgG8ufHkH6q/1eGKRJDXR28I6rwXeCfw0Iu4qln0A+DDw1Yi4FHgE+JP2lPgCWy6S1NysgZ6ZtwHR5O3zFracmfXX6i2XzKT+wUGSNKVUZ4r29/aQCeOTXkJXkg5WukAHbLtIUgPlCvSagS5JzZQr0KdG6E5dlKRDlDPQHaFL0iFKFegDRaA7F12SDlWqQLeHLknNlSvQ7aFLUlPlDHRH6JJ0iHIFetFy2e8IXZIOUa5Ad4QuSU2VMtCd5SJJhypVoA/4pagkNVWqQO+v1QBbLpLUSLkC3R66JDVV0kCf6HAlkrT0lDPQ7aFL0iHKFeie+i9JTZUq0Ptq9dvOGeiSdKhSBXpE0N/bwz5bLpJ0iFIFOhQ3inaELkmHKF+g9xroktTIrIEeEVdFxK6IuGfasg9GxPaIuKv4uaC9Zb7AEbokNdbKCP0LwPkNln8yM88ofm5Y2LKa6+/tcdqiJDUwa6Bn5q3Ak4tQS0tsuUhSY/Ppob8nIu4uWjJrFqyiWdhykaTG5hronwFeDJwB7AA+3mzFiNgUESMRMTI6OjrHzb3AloskNTanQM/MnZk5kZmTwGeBs2ZYd3NmDmfm8NDQ0FzrfF5/b4/XQ5ekBuYU6BFx7LSXbwXuabbuQhuwhy5JDfXOtkJEXAOcC6yLiMeAvwfOjYgzgAQeBt7VxhoPYA9dkhqbNdAz8+IGi69sQy0tsYcuSY15pqgkVUT5Ar3Ww35H6JJ0iPIFuiN0SWrIQJekiihloHs9dEk6VOkC/Yi+GmPjk/bRJekgpQv0dcsHAHjy12MdrkSSlpbSBfrQinqg73pmX4crkaSlpbSBPvrscx2uRJKWltIF+tFTgb7HEbokTVe6QJ/qoRvoknSg0gX6YF+NlYO9BrokHaR0gQ71Pvroswa6JE1X2kB3loskHaiUgX70ikFH6JJ0kFIG+tCKAXvoknSQ0gb63rEJfr1vvNOlSNKSUc5Ad+qiJB2inIE+dfq/gS5JzytloB+90hG6JB2slIH+QsvF67lI0pRSBvqaZf3UesKpi5I0TSkDvacnWLe835aLJE0za6BHxFURsSsi7pm2bG1E3BgR24rHNe0t81DORZekA7UyQv8CcP5By64AbsrMjcBNxetFdfSKQWe5SNI0swZ6Zt4KPHnQ4guBq4vnVwMXLXBdsxpa7ghdkqabaw99fWbuKJ4/DqxfoHpaNrRigF/9eoyJyVzsTUvSkjTvL0UzM4GmqRoRmyJiJCJGRkdH57u55w2tGGBiMnlqrzeLliSYe6DvjIhjAYrHXc1WzMzNmTmcmcNDQ0Nz3NyhhrwVnSQdYK6Bfj1wSfH8EuCbC1NO67y3qCQdqJVpi9cAtwOnRsRjEXEp8GHgTRGxDXhj8XpReT0XSTpQ72wrZObFTd46b4FrOSzeLFqSDlTKM0UBjhzo5cj+moEuSYXSBjp4s2hJmq78ge4VFyUJKHmgH71i0JaLJBVKHehDKwac5SJJhdIH+p7nxnlu/0SnS5Gkjit9oAPsesZRuiSVOtCPWTkIwOPP+MWoJJU70FcZ6JI0pdSBvr4Yoe8y0CWp3IG+crCXwb4eHn/aQJekUgd6RHDMykFbLpJEyQMd6m0XZ7lIUgUC/ZhVjtAlCSoQ6OuL
lkv9TniS1L0qEehj45Ps3ru/06VIUkeVPtCnTi7a6VUXJXW50gf6+pX10/+duiip21Ug0IsRul+MSupylQn0x5926qKk7lb6QO/v7eGoI/vtoUvqeqUPdICjVw6y0x66pC5XiUA/ZuWAJxdJ6nq98/nliHgY2ANMAOOZObwQRR2uY1YN8tPtT3di05K0ZMwr0Au/l5lPLMC/M2frVw7yxLNj7J+YpK9WiQ8dknTYKpF+z18X3RtGS+pi8w30BP4nIrZExKZGK0TEpogYiYiR0dHReW6usedvRecXo5K62HwD/XWZ+XLgLcC7I+Kcg1fIzM2ZOZyZw0NDQ/PcXGPeuUiS5hnombm9eNwFfB04ayGKOlzeW1SS5hHoEXFkRKyYeg68GbhnoQo7HGuW9dFf6zHQJXW1+cxyWQ98PSKm/p3/ysxvL0hVhykiOHrlgCcXSepqcw70zHwIOH0Ba5mXY1YOstNb0UnqYpWYtgiwftWgV1yU1NWqE+grvBWdpO5WmUA/ZtUAe8cm2LNvvNOlSFJHVCbQT1izDIBtO5/tcCWS1BmVCfSzX3QUEfB/29pzNqokLXWVCfQ1R/Zz+vGrueXnBrqk7lSZQAc455QhfvLobnbvHet0KZK06CoV6L97yhCTCbc90NGr+UpSR1Qq0E8/fhUrB3u55X7bLpK6T6UCvbfWw+s3DnHrtlHno0vqOpUKdKi3XXY+s4/7d+7pdCmStKgqF+ivP2UdALc620VSl6lcoB+76ghOXb/C6YuSuk7lAh3gnFPW8aNfPMXeMS8DIKl7VDLQz/vt9YxNTHLdnds7XYokLZpKBvqrTl7Ly09czae++wDP7Z/odDmStCgqGegRwfvefCqPP/Mc1/zwl50uR5IWRSUDHeA1L1nHq190FJ+++UF76ZK6QmUDHeDyN5/CE8/u44u3P9LpUiSp7Sod6MMb1nLuqUP8xy0P8vTe/Z0uR5LaqtKBDvC+N5/KnufGefvm29m++zedLkeS2qbygf6y41bxhT9/Jdt3/4YLP/V9fvLo7k6XJEltMa9Aj4jzI+L+iHggIq5YqKIW2us3DnHdX76GI/p7ePvm2/n0zQ/4RamkyplzoEdEDfg08BbgNODiiDhtoQpbaBvXr+Abf/VaXveSdXz0O/dzzj9/j89//xf21iVVRsz1MrMR8Wrgg5n5+8Xr9wNk5j81+53h4eEcGRmZ0/YW0pZHnuSj37mfHzz0JBHw0mNW8qqT17LhqGWsObKf1cv6WbOsjzXL+lm9rI/lA71ERKfLltSlImJLZg7Ptl7vPLZxHPDotNePAa+ax7+3aF5x0lqu+YuzufOXu7lt2xPc8Ytf8eUf/ZLn9k82/Z2B3h4G+2r01YKeCGo99ceeHuiJIKif0PR87E/L/+l/Clr5w+CfDql6/vGPf4dXbljb1m3MJ9BbEhGbgE0AJ554Yrs317KI4BUnreEVJ60BNjIxmezeO8ZTe/c///jU3jF27x3j2X0T7Ns/wXP7JxibSDKTiclkIhMSJjOZrD8FOODmGgd8/mnhw1C2spKk0jmir9b2bcwn0LcDJ0x7fXyx7ACZuRnYDPWWyzy211a1nuCo5QMctXyg06VI0pzMZ5bLj4CNEXFyRPQD7wCuX5iyJEmHa84j9Mwcj4j3AN8BasBVmXnvglUmSTos8+qhZ+YNwA0LVIskaR4qf6aoJHULA12SKsJAl6SKMNAlqSIMdEmqiDlfy2VOG4sYBeZ6+6B1wBMLWE5ZdON+d+M+Q3fudzfuMxz+fp+UmUOzrbSogT4fETHSysVpqqYb97sb9xm6c7+7cZ+hfftty0WSKsJAl6SKKFOgb+50AR3SjfvdjfsM3bnf3bjP0Kb9Lk0PXZI0szKN0CVJMyhFoJflZtTzEREnRMTNEXFfRNwbEZcVy9dGxI0Rsa14XNPpWhdaRNQi4scR8a3i9ckRcUdxvL9SXJ65UiJidURcGxE/i4itEfHqqh/riPib4r/teyLimogY
rOKxjoirImJXRNwzbVnDYxt1/1bs/90R8fL5bHvJB3rZbkY9D+PA5Zl5GnA28O5iP68AbsrMjcBNxeuquQzYOu31R4BPZuZLgKeASztSVXv9K/DtzHwpcDr1/a/ssY6I44C/BoYz82XUL7n9Dqp5rL8AnH/QsmbH9i3AxuJnE/CZ+Wx4yQc6cBbwQGY+lJljwJeBCztc04LLzB2ZeWfxfA/1/8GPo76vVxerXQ1c1JkK2yMijgf+APhc8TqANwDXFqtUcZ9XAecAVwJk5lhm7qbix5r65bqPiIheYBmwgwoe68y8FXjyoMXNju2FwBez7gfA6og4dq7bLkOgN7oZ9XEdqmVRRMQG4EzgDmB9Zu4o3nocWN+hstrlX4C/Babu0H0UsDszx4vXVTzeJwOjwOeLVtPnIuJIKnysM3M78DHgl9SD/GlgC9U/1lOaHdsFzbcyBHpXiYjlwNeA92bmM9Pfy/qUpMpMS4qIPwR2ZeaWTteyyHqBlwOfycwzgV9zUHulgsd6DfXR6MnAbwFHcmhboiu089iWIdBbuhl1FUREH/Uw/1JmXlcs3jn1Eax43NWp+trgtcAfRcTD1Ftpb6DeW15dfCyHah7vx4DHMvOO4vW11AO+ysf6jcAvMnM0M/cD11E//lU/1lOaHdsFzbcyBHpX3Iy66B1fCWzNzE9Me+t64JLi+SXANxe7tnbJzPdn5vGZuYH6cf1uZv4pcDPwtmK1Su0zQGY+DjwaEacWi84D7qPCx5p6q+XsiFhW/Lc+tc+VPtbTNDu21wN/Vsx2ORt4elpr5vBl5pL/AS4Afg48CPxdp+tp0z6+jvrHsLuBu4qfC6j3lG8CtgH/C6ztdK1t2v9zgW8Vz18E/BB4APhvYKDT9bVhf88ARorj/Q1gTdWPNfAh4GfAPcB/AgNVPNbANdS/J9hP/dPYpc2OLRDUZ/E9CPyU+iygOW/bM0UlqSLK0HKRJLXAQJekijDQJakiDHRJqggDXZIqwkCXpIow0CWpIgx0SaqI/we/pxwkk/9UdgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(grads)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The weights decay even faster now."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We looked at a lot of charts, but the most important point is that the LSTM basically has control over how much of the gradient to allow to flow through each timestep. This is what makes them so easy to train."
   ]
  },
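  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make this concrete, here is a minimal sketch (separate from the model code in this notebook) that uses autograd to check how the gradient passes through a single cell-state update. The gate values are random placeholders chosen purely for illustration, and `c_prev` is an assumed name for the previous cell state. Because `c_t = f_t * c_prev + i_t * g_t`, the gradient that reaches `c_prev` along this path should be scaled element-wise by the forget gate `f_t`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "hidden = 4\n",
    "\n",
    "# previous cell state; we want the gradient with respect to this tensor\n",
    "c_prev = torch.randn(hidden, requires_grad=True)\n",
    "\n",
    "# illustrative gate activations (random placeholders, not from a trained model)\n",
    "f_t = torch.sigmoid(torch.randn(hidden))  # forget gate\n",
    "i_t = torch.sigmoid(torch.randn(hidden))  # input gate\n",
    "g_t = torch.tanh(torch.randn(hidden))     # candidate cell state\n",
    "\n",
    "# the same cell-state update used in the LSTM\n",
    "c_t = f_t * c_prev + i_t * g_t\n",
    "c_t.sum().backward()\n",
    "\n",
    "# the gradient reaching c_prev equals the forget gate, element by element\n",
    "print(torch.allclose(c_prev.grad, f_t))"
   ]
  },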
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Making our LSTM Faster"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Remember how slow our implementation of the LSTM was slow? Let's see how we can speed it up."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you look at the code for our LSTM carefully, you'll notice that there is a lot of shared processing that could be batched together. For instance, the input and forget gates are both computed based on a linear transformation of the input and the hidden states.\n",
    "\n",
    "\n",
    "We can group these computations into just two matrix multiplications. The code now looks like this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "class OptimizedLSTM(nn.Module):\n",
    "    def __init__(self, input_sz: int, hidden_sz: int):\n",
    "        super().__init__()\n",
    "        self.input_sz = input_sz\n",
    "        self.hidden_size = hidden_sz\n",
    "        self.weight_ih = Parameter(torch.Tensor(input_sz, hidden_sz * 4))\n",
    "        self.weight_hh = Parameter(torch.Tensor(hidden_sz, hidden_sz * 4))\n",
    "        self.bias = Parameter(torch.Tensor(hidden_sz * 4))\n",
    "        self.init_weights()\n",
    "    \n",
    "    def init_weights(self):\n",
    "        for p in self.parameters():\n",
    "            if p.data.ndimension() >= 2:\n",
    "                nn.init.xavier_uniform_(p.data)\n",
    "            else:\n",
    "                nn.init.zeros_(p.data)\n",
    "        \n",
    "    def forward(self, x: torch.Tensor, \n",
    "                init_states: Optional[Tuple[torch.Tensor]]=None\n",
    "               ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:\n",
    "        \"\"\"Assumes x is of shape (batch, sequence, feature)\"\"\"\n",
    "        bs, seq_sz, _ = x.size()\n",
    "        hidden_seq = []\n",
    "        if init_states is None:\n",
    "            h_t, c_t = (torch.zeros(self.hidden_size).to(x.device), \n",
    "                        torch.zeros(self.hidden_size).to(x.device))\n",
    "        else:\n",
    "            h_t, c_t = init_states\n",
    "        \n",
    "        HS = self.hidden_size\n",
    "        for t in range(seq_sz):\n",
    "            x_t = x[:, t, :]\n",
    "            # batch the computations into a single matrix multiplication\n",
    "            gates = x_t @ self.weight_ih + h_t @ self.weight_hh + self.bias\n",
    "            i_t, f_t, g_t, o_t = (\n",
    "                torch.sigmoid(gates[:, :HS]), # input\n",
    "                torch.sigmoid(gates[:, HS:HS*2]), # forget\n",
    "                torch.tanh(gates[:, HS*2:HS*3]),\n",
    "                torch.sigmoid(gates[:, HS*3:]), # output\n",
    "            )\n",
    "            c_t = f_t * c_t + i_t * g_t\n",
    "            h_t = o_t * torch.tanh(c_t)\n",
    "            hidden_seq.append(h_t.unsqueeze(Dim.batch))\n",
    "        hidden_seq = torch.cat(hidden_seq, dim=Dim.batch)\n",
    "        # reshape from shape (sequence, batch, feature) to (batch, sequence, feature)\n",
    "        hidden_seq = hidden_seq.transpose(Dim.batch, Dim.seq).contiguous()\n",
    "        return hidden_seq, (h_t, c_t)"
   ]
  },
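  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The grouping in `forward` above relies on a simple fact: multiplying by several weight matrices separately gives the same result as multiplying once by their concatenation and slicing the output. The following is a small sketch of that equivalence, using made-up sizes and random weights; `per_gate` and `fused_weight` are illustrative names and are not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "torch.manual_seed(0)\n",
    "batch, input_sz, hidden_sz = 3, 5, 4\n",
    "x_t = torch.randn(batch, input_sz)\n",
    "\n",
    "# one weight matrix per gate (input, forget, candidate, output)\n",
    "per_gate = [torch.randn(input_sz, hidden_sz) for _ in range(4)]\n",
    "separate = [x_t @ w for w in per_gate]  # four small matrix multiplications\n",
    "\n",
    "# concatenate along the output dimension and multiply once\n",
    "fused_weight = torch.cat(per_gate, dim=1)  # (input_sz, 4 * hidden_sz)\n",
    "fused = x_t @ fused_weight                 # (batch, 4 * hidden_sz)\n",
    "chunks = fused.split(hidden_sz, dim=1)     # four (batch, hidden_sz) slices\n",
    "\n",
    "# each slice matches the corresponding separate product\n",
    "print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(separate, chunks)))"
   ]
  },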
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "lstm = OptimizedLSTM(100, 32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = torch.arange(5 * 10 * 100).view((5, 10, 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "hs, _ = lstm(a.float())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([5, 10, 32])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hs.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's see how the training speed changes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "02/15/2019 16:53:36 - WARNING - allennlp.training.trainer -   You provided a validation dataset but patience was set to None, meaning that early stopping is disabled\n",
      "02/15/2019 16:53:36 - INFO - allennlp.training.trainer -   Beginning training.\n",
      "02/15/2019 16:53:36 - INFO - allennlp.training.trainer -   Epoch 0/0\n",
      "02/15/2019 16:53:36 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1905.47968\n",
      "02/15/2019 16:53:36 - INFO - allennlp.training.trainer -   Training\n",
      "loss: 2.7769 ||: 100%|██████████| 338/338 [13:52<00:00,  2.02s/it]\n",
      "02/15/2019 17:07:28 - INFO - allennlp.training.trainer -   Validating\n",
      "loss: 2.4106 ||: 100%|██████████| 38/38 [00:06<00:00,  6.66it/s]\n",
      "02/15/2019 17:07:35 - INFO - allennlp.training.trainer -                     Training |  Validation\n",
      "02/15/2019 17:07:35 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1905.480  |       N/A\n",
      "02/15/2019 17:07:35 - INFO - allennlp.training.trainer -   loss          |     2.777  |     2.411\n",
      "02/15/2019 17:07:35 - INFO - allennlp.training.trainer -   Epoch duration: 00:13:58\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'peak_cpu_memory_MB': 1905.47968,\n",
       " 'training_duration': '00:13:58',\n",
       " 'training_start_epoch': 0,\n",
       " 'training_epochs': 0,\n",
       " 'epoch': 0,\n",
       " 'training_loss': 2.7768504443253286,\n",
       " 'training_cpu_memory_MB': 1905.47968,\n",
       " 'validation_loss': 2.4106190016395166,\n",
       " 'best_epoch': 0,\n",
       " 'best_validation_loss': 2.4106190016395166}"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lm_optimized = LanguageModel(OptimizedLSTM(50, 125), vocab)\n",
    "train(lm_optimized, epochs=N_EPOCHS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model is faster now, but still not quite as fast as we might want it to be. To really make our LSTM fast, we'll need to pass it over to CuDNN. But that's a topic for another post/notebook."
   ]
  },
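  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a pointer in that direction, PyTorch's built-in `nn.LSTM` can dispatch to cuDNN kernels when it runs on a CUDA device with cuDNN available. The sketch below only shows that it accepts the same `(batch, sequence, feature)` layout as our implementation when `batch_first=True` is set. The sizes mirror the small example above, the names `builtin_lstm` and `x` are illustrative, and no speed comparison is made here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# fall back to the CPU when no GPU is available (cuDNN is only used on CUDA)\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# batch_first=True matches the (batch, sequence, feature) layout used above\n",
    "builtin_lstm = nn.LSTM(input_size=100, hidden_size=32, batch_first=True).to(device)\n",
    "x = torch.randn(5, 10, 100, device=device)\n",
    "\n",
    "# out holds the hidden state at every timestep; h_n and c_n are the final states\n",
    "out, (h_n, c_n) = builtin_lstm(x)\n",
    "print(out.shape)"
   ]
  },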
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
